fix
This commit is contained in:
421
venv/lib/python3.11/site-packages/pydantic/__init__.py
Normal file
421
venv/lib/python3.11/site-packages/pydantic/__init__.py
Normal file
@@ -0,0 +1,421 @@
|
||||
import typing
|
||||
from importlib import import_module
|
||||
|
||||
from ._migration import getattr_migration
|
||||
from .version import VERSION
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# import of virtually everything is supported via `__getattr__` below,
|
||||
# but we need them here for type checking and IDE support
|
||||
import pydantic_core
|
||||
from pydantic_core.core_schema import (
|
||||
FieldSerializationInfo,
|
||||
SerializationInfo,
|
||||
SerializerFunctionWrapHandler,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
)
|
||||
|
||||
from . import dataclasses
|
||||
from ._internal._generate_schema import GenerateSchema as GenerateSchema
|
||||
from .aliases import AliasChoices, AliasGenerator, AliasPath
|
||||
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
from .config import ConfigDict, with_config
|
||||
from .errors import *
|
||||
from .fields import Field, PrivateAttr, computed_field
|
||||
from .functional_serializers import (
|
||||
PlainSerializer,
|
||||
SerializeAsAny,
|
||||
WrapSerializer,
|
||||
field_serializer,
|
||||
model_serializer,
|
||||
)
|
||||
from .functional_validators import (
|
||||
AfterValidator,
|
||||
BeforeValidator,
|
||||
InstanceOf,
|
||||
PlainValidator,
|
||||
SkipValidation,
|
||||
WrapValidator,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from .json_schema import WithJsonSchema
|
||||
from .main import *
|
||||
from .networks import *
|
||||
from .type_adapter import TypeAdapter
|
||||
from .types import *
|
||||
from .validate_call_decorator import validate_call
|
||||
from .warnings import (
|
||||
PydanticDeprecatedSince20,
|
||||
PydanticDeprecatedSince26,
|
||||
PydanticDeprecatedSince29,
|
||||
PydanticDeprecationWarning,
|
||||
PydanticExperimentalWarning,
|
||||
)
|
||||
|
||||
# this encourages pycharm to import `ValidationError` from here, not pydantic_core
|
||||
ValidationError = pydantic_core.ValidationError
|
||||
from .deprecated.class_validators import root_validator, validator
|
||||
from .deprecated.config import BaseConfig, Extra
|
||||
from .deprecated.tools import *
|
||||
from .root_model import RootModel
|
||||
|
||||
__version__ = VERSION
|
||||
__all__ = (
|
||||
# dataclasses
|
||||
'dataclasses',
|
||||
# functional validators
|
||||
'field_validator',
|
||||
'model_validator',
|
||||
'AfterValidator',
|
||||
'BeforeValidator',
|
||||
'PlainValidator',
|
||||
'WrapValidator',
|
||||
'SkipValidation',
|
||||
'InstanceOf',
|
||||
# JSON Schema
|
||||
'WithJsonSchema',
|
||||
# deprecated V1 functional validators, these are imported via `__getattr__` below
|
||||
'root_validator',
|
||||
'validator',
|
||||
# functional serializers
|
||||
'field_serializer',
|
||||
'model_serializer',
|
||||
'PlainSerializer',
|
||||
'SerializeAsAny',
|
||||
'WrapSerializer',
|
||||
# config
|
||||
'ConfigDict',
|
||||
'with_config',
|
||||
# deprecated V1 config, these are imported via `__getattr__` below
|
||||
'BaseConfig',
|
||||
'Extra',
|
||||
# validate_call
|
||||
'validate_call',
|
||||
# errors
|
||||
'PydanticErrorCodes',
|
||||
'PydanticUserError',
|
||||
'PydanticSchemaGenerationError',
|
||||
'PydanticImportError',
|
||||
'PydanticUndefinedAnnotation',
|
||||
'PydanticInvalidForJsonSchema',
|
||||
# fields
|
||||
'Field',
|
||||
'computed_field',
|
||||
'PrivateAttr',
|
||||
# alias
|
||||
'AliasChoices',
|
||||
'AliasGenerator',
|
||||
'AliasPath',
|
||||
# main
|
||||
'BaseModel',
|
||||
'create_model',
|
||||
# network
|
||||
'AnyUrl',
|
||||
'AnyHttpUrl',
|
||||
'FileUrl',
|
||||
'HttpUrl',
|
||||
'FtpUrl',
|
||||
'WebsocketUrl',
|
||||
'AnyWebsocketUrl',
|
||||
'UrlConstraints',
|
||||
'EmailStr',
|
||||
'NameEmail',
|
||||
'IPvAnyAddress',
|
||||
'IPvAnyInterface',
|
||||
'IPvAnyNetwork',
|
||||
'PostgresDsn',
|
||||
'CockroachDsn',
|
||||
'AmqpDsn',
|
||||
'RedisDsn',
|
||||
'MongoDsn',
|
||||
'KafkaDsn',
|
||||
'NatsDsn',
|
||||
'MySQLDsn',
|
||||
'MariaDBDsn',
|
||||
'ClickHouseDsn',
|
||||
'SnowflakeDsn',
|
||||
'validate_email',
|
||||
# root_model
|
||||
'RootModel',
|
||||
# deprecated tools, these are imported via `__getattr__` below
|
||||
'parse_obj_as',
|
||||
'schema_of',
|
||||
'schema_json_of',
|
||||
# types
|
||||
'Strict',
|
||||
'StrictStr',
|
||||
'conbytes',
|
||||
'conlist',
|
||||
'conset',
|
||||
'confrozenset',
|
||||
'constr',
|
||||
'StringConstraints',
|
||||
'ImportString',
|
||||
'conint',
|
||||
'PositiveInt',
|
||||
'NegativeInt',
|
||||
'NonNegativeInt',
|
||||
'NonPositiveInt',
|
||||
'confloat',
|
||||
'PositiveFloat',
|
||||
'NegativeFloat',
|
||||
'NonNegativeFloat',
|
||||
'NonPositiveFloat',
|
||||
'FiniteFloat',
|
||||
'condecimal',
|
||||
'condate',
|
||||
'UUID1',
|
||||
'UUID3',
|
||||
'UUID4',
|
||||
'UUID5',
|
||||
'FilePath',
|
||||
'DirectoryPath',
|
||||
'NewPath',
|
||||
'Json',
|
||||
'Secret',
|
||||
'SecretStr',
|
||||
'SecretBytes',
|
||||
'StrictBool',
|
||||
'StrictBytes',
|
||||
'StrictInt',
|
||||
'StrictFloat',
|
||||
'PaymentCardNumber',
|
||||
'ByteSize',
|
||||
'PastDate',
|
||||
'FutureDate',
|
||||
'PastDatetime',
|
||||
'FutureDatetime',
|
||||
'AwareDatetime',
|
||||
'NaiveDatetime',
|
||||
'AllowInfNan',
|
||||
'EncoderProtocol',
|
||||
'EncodedBytes',
|
||||
'EncodedStr',
|
||||
'Base64Encoder',
|
||||
'Base64Bytes',
|
||||
'Base64Str',
|
||||
'Base64UrlBytes',
|
||||
'Base64UrlStr',
|
||||
'GetPydanticSchema',
|
||||
'Tag',
|
||||
'Discriminator',
|
||||
'JsonValue',
|
||||
'FailFast',
|
||||
# type_adapter
|
||||
'TypeAdapter',
|
||||
# version
|
||||
'__version__',
|
||||
'VERSION',
|
||||
# warnings
|
||||
'PydanticDeprecatedSince20',
|
||||
'PydanticDeprecatedSince26',
|
||||
'PydanticDeprecatedSince29',
|
||||
'PydanticDeprecationWarning',
|
||||
'PydanticExperimentalWarning',
|
||||
# annotated handlers
|
||||
'GetCoreSchemaHandler',
|
||||
'GetJsonSchemaHandler',
|
||||
# generate schema from ._internal
|
||||
'GenerateSchema',
|
||||
# pydantic_core
|
||||
'ValidationError',
|
||||
'ValidationInfo',
|
||||
'SerializationInfo',
|
||||
'ValidatorFunctionWrapHandler',
|
||||
'FieldSerializationInfo',
|
||||
'SerializerFunctionWrapHandler',
|
||||
'OnErrorOmit',
|
||||
)
|
||||
|
||||
# A mapping of {<member name>: (package, <module name>)} defining dynamic imports
|
||||
_dynamic_imports: 'dict[str, tuple[str, str]]' = {
|
||||
'dataclasses': (__spec__.parent, '__module__'),
|
||||
# functional validators
|
||||
'field_validator': (__spec__.parent, '.functional_validators'),
|
||||
'model_validator': (__spec__.parent, '.functional_validators'),
|
||||
'AfterValidator': (__spec__.parent, '.functional_validators'),
|
||||
'BeforeValidator': (__spec__.parent, '.functional_validators'),
|
||||
'PlainValidator': (__spec__.parent, '.functional_validators'),
|
||||
'WrapValidator': (__spec__.parent, '.functional_validators'),
|
||||
'SkipValidation': (__spec__.parent, '.functional_validators'),
|
||||
'InstanceOf': (__spec__.parent, '.functional_validators'),
|
||||
# JSON Schema
|
||||
'WithJsonSchema': (__spec__.parent, '.json_schema'),
|
||||
# functional serializers
|
||||
'field_serializer': (__spec__.parent, '.functional_serializers'),
|
||||
'model_serializer': (__spec__.parent, '.functional_serializers'),
|
||||
'PlainSerializer': (__spec__.parent, '.functional_serializers'),
|
||||
'SerializeAsAny': (__spec__.parent, '.functional_serializers'),
|
||||
'WrapSerializer': (__spec__.parent, '.functional_serializers'),
|
||||
# config
|
||||
'ConfigDict': (__spec__.parent, '.config'),
|
||||
'with_config': (__spec__.parent, '.config'),
|
||||
# validate call
|
||||
'validate_call': (__spec__.parent, '.validate_call_decorator'),
|
||||
# errors
|
||||
'PydanticErrorCodes': (__spec__.parent, '.errors'),
|
||||
'PydanticUserError': (__spec__.parent, '.errors'),
|
||||
'PydanticSchemaGenerationError': (__spec__.parent, '.errors'),
|
||||
'PydanticImportError': (__spec__.parent, '.errors'),
|
||||
'PydanticUndefinedAnnotation': (__spec__.parent, '.errors'),
|
||||
'PydanticInvalidForJsonSchema': (__spec__.parent, '.errors'),
|
||||
# fields
|
||||
'Field': (__spec__.parent, '.fields'),
|
||||
'computed_field': (__spec__.parent, '.fields'),
|
||||
'PrivateAttr': (__spec__.parent, '.fields'),
|
||||
# alias
|
||||
'AliasChoices': (__spec__.parent, '.aliases'),
|
||||
'AliasGenerator': (__spec__.parent, '.aliases'),
|
||||
'AliasPath': (__spec__.parent, '.aliases'),
|
||||
# main
|
||||
'BaseModel': (__spec__.parent, '.main'),
|
||||
'create_model': (__spec__.parent, '.main'),
|
||||
# network
|
||||
'AnyUrl': (__spec__.parent, '.networks'),
|
||||
'AnyHttpUrl': (__spec__.parent, '.networks'),
|
||||
'FileUrl': (__spec__.parent, '.networks'),
|
||||
'HttpUrl': (__spec__.parent, '.networks'),
|
||||
'FtpUrl': (__spec__.parent, '.networks'),
|
||||
'WebsocketUrl': (__spec__.parent, '.networks'),
|
||||
'AnyWebsocketUrl': (__spec__.parent, '.networks'),
|
||||
'UrlConstraints': (__spec__.parent, '.networks'),
|
||||
'EmailStr': (__spec__.parent, '.networks'),
|
||||
'NameEmail': (__spec__.parent, '.networks'),
|
||||
'IPvAnyAddress': (__spec__.parent, '.networks'),
|
||||
'IPvAnyInterface': (__spec__.parent, '.networks'),
|
||||
'IPvAnyNetwork': (__spec__.parent, '.networks'),
|
||||
'PostgresDsn': (__spec__.parent, '.networks'),
|
||||
'CockroachDsn': (__spec__.parent, '.networks'),
|
||||
'AmqpDsn': (__spec__.parent, '.networks'),
|
||||
'RedisDsn': (__spec__.parent, '.networks'),
|
||||
'MongoDsn': (__spec__.parent, '.networks'),
|
||||
'KafkaDsn': (__spec__.parent, '.networks'),
|
||||
'NatsDsn': (__spec__.parent, '.networks'),
|
||||
'MySQLDsn': (__spec__.parent, '.networks'),
|
||||
'MariaDBDsn': (__spec__.parent, '.networks'),
|
||||
'ClickHouseDsn': (__spec__.parent, '.networks'),
|
||||
'SnowflakeDsn': (__spec__.parent, '.networks'),
|
||||
'validate_email': (__spec__.parent, '.networks'),
|
||||
# root_model
|
||||
'RootModel': (__spec__.parent, '.root_model'),
|
||||
# types
|
||||
'Strict': (__spec__.parent, '.types'),
|
||||
'StrictStr': (__spec__.parent, '.types'),
|
||||
'conbytes': (__spec__.parent, '.types'),
|
||||
'conlist': (__spec__.parent, '.types'),
|
||||
'conset': (__spec__.parent, '.types'),
|
||||
'confrozenset': (__spec__.parent, '.types'),
|
||||
'constr': (__spec__.parent, '.types'),
|
||||
'StringConstraints': (__spec__.parent, '.types'),
|
||||
'ImportString': (__spec__.parent, '.types'),
|
||||
'conint': (__spec__.parent, '.types'),
|
||||
'PositiveInt': (__spec__.parent, '.types'),
|
||||
'NegativeInt': (__spec__.parent, '.types'),
|
||||
'NonNegativeInt': (__spec__.parent, '.types'),
|
||||
'NonPositiveInt': (__spec__.parent, '.types'),
|
||||
'confloat': (__spec__.parent, '.types'),
|
||||
'PositiveFloat': (__spec__.parent, '.types'),
|
||||
'NegativeFloat': (__spec__.parent, '.types'),
|
||||
'NonNegativeFloat': (__spec__.parent, '.types'),
|
||||
'NonPositiveFloat': (__spec__.parent, '.types'),
|
||||
'FiniteFloat': (__spec__.parent, '.types'),
|
||||
'condecimal': (__spec__.parent, '.types'),
|
||||
'condate': (__spec__.parent, '.types'),
|
||||
'UUID1': (__spec__.parent, '.types'),
|
||||
'UUID3': (__spec__.parent, '.types'),
|
||||
'UUID4': (__spec__.parent, '.types'),
|
||||
'UUID5': (__spec__.parent, '.types'),
|
||||
'FilePath': (__spec__.parent, '.types'),
|
||||
'DirectoryPath': (__spec__.parent, '.types'),
|
||||
'NewPath': (__spec__.parent, '.types'),
|
||||
'Json': (__spec__.parent, '.types'),
|
||||
'Secret': (__spec__.parent, '.types'),
|
||||
'SecretStr': (__spec__.parent, '.types'),
|
||||
'SecretBytes': (__spec__.parent, '.types'),
|
||||
'StrictBool': (__spec__.parent, '.types'),
|
||||
'StrictBytes': (__spec__.parent, '.types'),
|
||||
'StrictInt': (__spec__.parent, '.types'),
|
||||
'StrictFloat': (__spec__.parent, '.types'),
|
||||
'PaymentCardNumber': (__spec__.parent, '.types'),
|
||||
'ByteSize': (__spec__.parent, '.types'),
|
||||
'PastDate': (__spec__.parent, '.types'),
|
||||
'FutureDate': (__spec__.parent, '.types'),
|
||||
'PastDatetime': (__spec__.parent, '.types'),
|
||||
'FutureDatetime': (__spec__.parent, '.types'),
|
||||
'AwareDatetime': (__spec__.parent, '.types'),
|
||||
'NaiveDatetime': (__spec__.parent, '.types'),
|
||||
'AllowInfNan': (__spec__.parent, '.types'),
|
||||
'EncoderProtocol': (__spec__.parent, '.types'),
|
||||
'EncodedBytes': (__spec__.parent, '.types'),
|
||||
'EncodedStr': (__spec__.parent, '.types'),
|
||||
'Base64Encoder': (__spec__.parent, '.types'),
|
||||
'Base64Bytes': (__spec__.parent, '.types'),
|
||||
'Base64Str': (__spec__.parent, '.types'),
|
||||
'Base64UrlBytes': (__spec__.parent, '.types'),
|
||||
'Base64UrlStr': (__spec__.parent, '.types'),
|
||||
'GetPydanticSchema': (__spec__.parent, '.types'),
|
||||
'Tag': (__spec__.parent, '.types'),
|
||||
'Discriminator': (__spec__.parent, '.types'),
|
||||
'JsonValue': (__spec__.parent, '.types'),
|
||||
'OnErrorOmit': (__spec__.parent, '.types'),
|
||||
'FailFast': (__spec__.parent, '.types'),
|
||||
# type_adapter
|
||||
'TypeAdapter': (__spec__.parent, '.type_adapter'),
|
||||
# warnings
|
||||
'PydanticDeprecatedSince20': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecatedSince26': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecatedSince29': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecationWarning': (__spec__.parent, '.warnings'),
|
||||
'PydanticExperimentalWarning': (__spec__.parent, '.warnings'),
|
||||
# annotated handlers
|
||||
'GetCoreSchemaHandler': (__spec__.parent, '.annotated_handlers'),
|
||||
'GetJsonSchemaHandler': (__spec__.parent, '.annotated_handlers'),
|
||||
# generate schema from ._internal
|
||||
'GenerateSchema': (__spec__.parent, '._internal._generate_schema'),
|
||||
# pydantic_core stuff
|
||||
'ValidationError': ('pydantic_core', '.'),
|
||||
'ValidationInfo': ('pydantic_core', '.core_schema'),
|
||||
'SerializationInfo': ('pydantic_core', '.core_schema'),
|
||||
'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'),
|
||||
'FieldSerializationInfo': ('pydantic_core', '.core_schema'),
|
||||
'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'),
|
||||
# deprecated, mostly not included in __all__
|
||||
'root_validator': (__spec__.parent, '.deprecated.class_validators'),
|
||||
'validator': (__spec__.parent, '.deprecated.class_validators'),
|
||||
'BaseConfig': (__spec__.parent, '.deprecated.config'),
|
||||
'Extra': (__spec__.parent, '.deprecated.config'),
|
||||
'parse_obj_as': (__spec__.parent, '.deprecated.tools'),
|
||||
'schema_of': (__spec__.parent, '.deprecated.tools'),
|
||||
'schema_json_of': (__spec__.parent, '.deprecated.tools'),
|
||||
'FieldValidationInfo': ('pydantic_core', '.core_schema'),
|
||||
}
|
||||
_deprecated_dynamic_imports = {'FieldValidationInfo'}
|
||||
|
||||
_getattr_migration = getattr_migration(__name__)
|
||||
|
||||
|
||||
def __getattr__(attr_name: str) -> object:
|
||||
dynamic_attr = _dynamic_imports.get(attr_name)
|
||||
if dynamic_attr is None:
|
||||
return _getattr_migration(attr_name)
|
||||
|
||||
package, module_name = dynamic_attr
|
||||
|
||||
if module_name == '__module__':
|
||||
result = import_module(f'.{attr_name}', package=package)
|
||||
globals()[attr_name] = result
|
||||
return result
|
||||
else:
|
||||
module = import_module(module_name, package=package)
|
||||
result = getattr(module, attr_name)
|
||||
g = globals()
|
||||
for k, (_, v_module_name) in _dynamic_imports.items():
|
||||
if v_module_name == module_name and k not in _deprecated_dynamic_imports:
|
||||
g[k] = getattr(module, k)
|
||||
return result
|
||||
|
||||
|
||||
def __dir__() -> 'list[str]':
|
||||
return list(__all__)
|
||||
341
venv/lib/python3.11/site-packages/pydantic/_internal/_config.py
Normal file
341
venv/lib/python3.11/site-packages/pydantic/_internal/_config.py
Normal file
@@ -0,0 +1,341 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import (
|
||||
Literal,
|
||||
Self,
|
||||
)
|
||||
|
||||
from ..aliases import AliasGenerator
|
||||
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
|
||||
from ..errors import PydanticUserError
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._internal._schema_generation_shared import GenerateSchema
|
||||
from ..fields import ComputedFieldInfo, FieldInfo
|
||||
|
||||
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
|
||||
|
||||
|
||||
class ConfigWrapper:
|
||||
"""Internal wrapper for Config which exposes ConfigDict items as attributes."""
|
||||
|
||||
__slots__ = ('config_dict',)
|
||||
|
||||
config_dict: ConfigDict
|
||||
|
||||
# all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they
|
||||
# stop matching
|
||||
title: str | None
|
||||
str_to_lower: bool
|
||||
str_to_upper: bool
|
||||
str_strip_whitespace: bool
|
||||
str_min_length: int
|
||||
str_max_length: int | None
|
||||
extra: ExtraValues | None
|
||||
frozen: bool
|
||||
populate_by_name: bool
|
||||
use_enum_values: bool
|
||||
validate_assignment: bool
|
||||
arbitrary_types_allowed: bool
|
||||
from_attributes: bool
|
||||
# whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
|
||||
# to construct error `loc`s, default `True`
|
||||
loc_by_alias: bool
|
||||
alias_generator: Callable[[str], str] | AliasGenerator | None
|
||||
model_title_generator: Callable[[type], str] | None
|
||||
field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
|
||||
ignored_types: tuple[type, ...]
|
||||
allow_inf_nan: bool
|
||||
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
|
||||
json_encoders: dict[type[object], JsonEncoder] | None
|
||||
|
||||
# new in V2
|
||||
strict: bool
|
||||
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
|
||||
revalidate_instances: Literal['always', 'never', 'subclass-instances']
|
||||
ser_json_timedelta: Literal['iso8601', 'float']
|
||||
ser_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
val_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
ser_json_inf_nan: Literal['null', 'constants', 'strings']
|
||||
# whether to validate default values during validation, default False
|
||||
validate_default: bool
|
||||
validate_return: bool
|
||||
protected_namespaces: tuple[str, ...]
|
||||
hide_input_in_errors: bool
|
||||
defer_build: bool
|
||||
experimental_defer_build_mode: tuple[Literal['model', 'type_adapter'], ...]
|
||||
plugin_settings: dict[str, object] | None
|
||||
schema_generator: type[GenerateSchema] | None
|
||||
json_schema_serialization_defaults_required: bool
|
||||
json_schema_mode_override: Literal['validation', 'serialization', None]
|
||||
coerce_numbers_to_str: bool
|
||||
regex_engine: Literal['rust-regex', 'python-re']
|
||||
validation_error_cause: bool
|
||||
use_attribute_docstrings: bool
|
||||
cache_strings: bool | Literal['all', 'keys', 'none']
|
||||
|
||||
def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
|
||||
if check:
|
||||
self.config_dict = prepare_config(config)
|
||||
else:
|
||||
self.config_dict = cast(ConfigDict, config)
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self:
|
||||
"""Build a new `ConfigWrapper` instance for a `BaseModel`.
|
||||
|
||||
The config wrapper built based on (in descending order of priority):
|
||||
- options from `kwargs`
|
||||
- options from the `namespace`
|
||||
- options from the base classes (`bases`)
|
||||
|
||||
Args:
|
||||
bases: A tuple of base classes.
|
||||
namespace: The namespace of the class being created.
|
||||
kwargs: The kwargs passed to the class being created.
|
||||
|
||||
Returns:
|
||||
A `ConfigWrapper` instance for `BaseModel`.
|
||||
"""
|
||||
config_new = ConfigDict()
|
||||
for base in bases:
|
||||
config = getattr(base, 'model_config', None)
|
||||
if config:
|
||||
config_new.update(config.copy())
|
||||
|
||||
config_class_from_namespace = namespace.get('Config')
|
||||
config_dict_from_namespace = namespace.get('model_config')
|
||||
|
||||
raw_annotations = namespace.get('__annotations__', {})
|
||||
if raw_annotations.get('model_config') and not config_dict_from_namespace:
|
||||
raise PydanticUserError(
|
||||
'`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
|
||||
code='model-config-invalid-field-name',
|
||||
)
|
||||
|
||||
if config_class_from_namespace and config_dict_from_namespace:
|
||||
raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')
|
||||
|
||||
config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)
|
||||
|
||||
config_new.update(config_from_namespace)
|
||||
|
||||
for k in list(kwargs.keys()):
|
||||
if k in config_keys:
|
||||
config_new[k] = kwargs.pop(k)
|
||||
|
||||
return cls(config_new)
|
||||
|
||||
# we don't show `__getattr__` to type checkers so missing attributes cause errors
|
||||
if not TYPE_CHECKING: # pragma: no branch
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self.config_dict[name]
|
||||
except KeyError:
|
||||
try:
|
||||
return config_defaults[name]
|
||||
except KeyError:
|
||||
raise AttributeError(f'Config has no attribute {name!r}') from None
|
||||
|
||||
def core_config(self, obj: Any) -> core_schema.CoreConfig:
|
||||
"""Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.
|
||||
|
||||
Pass `obj=None` if you do not want to attempt to infer the `title`.
|
||||
|
||||
We don't use getattr here since we don't want to populate with defaults.
|
||||
|
||||
Args:
|
||||
obj: An object used to populate `title` if not set in config.
|
||||
|
||||
Returns:
|
||||
A `CoreConfig` object created from config.
|
||||
"""
|
||||
config = self.config_dict
|
||||
|
||||
core_config_values = {
|
||||
'title': config.get('title') or (obj and obj.__name__),
|
||||
'extra_fields_behavior': config.get('extra'),
|
||||
'allow_inf_nan': config.get('allow_inf_nan'),
|
||||
'populate_by_name': config.get('populate_by_name'),
|
||||
'str_strip_whitespace': config.get('str_strip_whitespace'),
|
||||
'str_to_lower': config.get('str_to_lower'),
|
||||
'str_to_upper': config.get('str_to_upper'),
|
||||
'strict': config.get('strict'),
|
||||
'ser_json_timedelta': config.get('ser_json_timedelta'),
|
||||
'ser_json_bytes': config.get('ser_json_bytes'),
|
||||
'val_json_bytes': config.get('val_json_bytes'),
|
||||
'ser_json_inf_nan': config.get('ser_json_inf_nan'),
|
||||
'from_attributes': config.get('from_attributes'),
|
||||
'loc_by_alias': config.get('loc_by_alias'),
|
||||
'revalidate_instances': config.get('revalidate_instances'),
|
||||
'validate_default': config.get('validate_default'),
|
||||
'str_max_length': config.get('str_max_length'),
|
||||
'str_min_length': config.get('str_min_length'),
|
||||
'hide_input_in_errors': config.get('hide_input_in_errors'),
|
||||
'coerce_numbers_to_str': config.get('coerce_numbers_to_str'),
|
||||
'regex_engine': config.get('regex_engine'),
|
||||
'validation_error_cause': config.get('validation_error_cause'),
|
||||
'cache_strings': config.get('cache_strings'),
|
||||
}
|
||||
|
||||
return core_schema.CoreConfig(**{k: v for k, v in core_config_values.items() if v is not None})
|
||||
|
||||
def __repr__(self):
|
||||
c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
|
||||
return f'ConfigWrapper({c})'
|
||||
|
||||
|
||||
class ConfigWrapperStack:
|
||||
"""A stack of `ConfigWrapper` instances."""
|
||||
|
||||
def __init__(self, config_wrapper: ConfigWrapper):
|
||||
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]
|
||||
|
||||
@property
|
||||
def tail(self) -> ConfigWrapper:
|
||||
return self._config_wrapper_stack[-1]
|
||||
|
||||
@contextmanager
|
||||
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
|
||||
if config_wrapper is None:
|
||||
yield
|
||||
return
|
||||
|
||||
if not isinstance(config_wrapper, ConfigWrapper):
|
||||
config_wrapper = ConfigWrapper(config_wrapper, check=False)
|
||||
|
||||
self._config_wrapper_stack.append(config_wrapper)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._config_wrapper_stack.pop()
|
||||
|
||||
|
||||
config_defaults = ConfigDict(
|
||||
title=None,
|
||||
str_to_lower=False,
|
||||
str_to_upper=False,
|
||||
str_strip_whitespace=False,
|
||||
str_min_length=0,
|
||||
str_max_length=None,
|
||||
# let the model / dataclass decide how to handle it
|
||||
extra=None,
|
||||
frozen=False,
|
||||
populate_by_name=False,
|
||||
use_enum_values=False,
|
||||
validate_assignment=False,
|
||||
arbitrary_types_allowed=False,
|
||||
from_attributes=False,
|
||||
loc_by_alias=True,
|
||||
alias_generator=None,
|
||||
model_title_generator=None,
|
||||
field_title_generator=None,
|
||||
ignored_types=(),
|
||||
allow_inf_nan=True,
|
||||
json_schema_extra=None,
|
||||
strict=False,
|
||||
revalidate_instances='never',
|
||||
ser_json_timedelta='iso8601',
|
||||
ser_json_bytes='utf8',
|
||||
val_json_bytes='utf8',
|
||||
ser_json_inf_nan='null',
|
||||
validate_default=False,
|
||||
validate_return=False,
|
||||
protected_namespaces=('model_',),
|
||||
hide_input_in_errors=False,
|
||||
json_encoders=None,
|
||||
defer_build=False,
|
||||
experimental_defer_build_mode=('model',),
|
||||
plugin_settings=None,
|
||||
schema_generator=None,
|
||||
json_schema_serialization_defaults_required=False,
|
||||
json_schema_mode_override=None,
|
||||
coerce_numbers_to_str=False,
|
||||
regex_engine='rust-regex',
|
||||
validation_error_cause=False,
|
||||
use_attribute_docstrings=False,
|
||||
cache_strings=True,
|
||||
)
|
||||
|
||||
|
||||
def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict:
|
||||
"""Create a `ConfigDict` instance from an existing dict, a class (e.g. old class-based config) or None.
|
||||
|
||||
Args:
|
||||
config: The input config.
|
||||
|
||||
Returns:
|
||||
A ConfigDict object created from config.
|
||||
"""
|
||||
if config is None:
|
||||
return ConfigDict()
|
||||
|
||||
if not isinstance(config, dict):
|
||||
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
config = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
|
||||
|
||||
config_dict = cast(ConfigDict, config)
|
||||
check_deprecated(config_dict)
|
||||
return config_dict
|
||||
|
||||
|
||||
config_keys = set(ConfigDict.__annotations__.keys())
|
||||
|
||||
|
||||
V2_REMOVED_KEYS = {
|
||||
'allow_mutation',
|
||||
'error_msg_templates',
|
||||
'fields',
|
||||
'getter_dict',
|
||||
'smart_union',
|
||||
'underscore_attrs_are_private',
|
||||
'json_loads',
|
||||
'json_dumps',
|
||||
'copy_on_model_validation',
|
||||
'post_init_call',
|
||||
}
|
||||
V2_RENAMED_KEYS = {
|
||||
'allow_population_by_field_name': 'populate_by_name',
|
||||
'anystr_lower': 'str_to_lower',
|
||||
'anystr_strip_whitespace': 'str_strip_whitespace',
|
||||
'anystr_upper': 'str_to_upper',
|
||||
'keep_untouched': 'ignored_types',
|
||||
'max_anystr_length': 'str_max_length',
|
||||
'min_anystr_length': 'str_min_length',
|
||||
'orm_mode': 'from_attributes',
|
||||
'schema_extra': 'json_schema_extra',
|
||||
'validate_all': 'validate_default',
|
||||
}
|
||||
|
||||
|
||||
def check_deprecated(config_dict: ConfigDict) -> None:
|
||||
"""Check for deprecated config keys and warn the user.
|
||||
|
||||
Args:
|
||||
config_dict: The input config.
|
||||
"""
|
||||
deprecated_removed_keys = V2_REMOVED_KEYS & config_dict.keys()
|
||||
deprecated_renamed_keys = V2_RENAMED_KEYS.keys() & config_dict.keys()
|
||||
if deprecated_removed_keys or deprecated_renamed_keys:
|
||||
renamings = {k: V2_RENAMED_KEYS[k] for k in sorted(deprecated_renamed_keys)}
|
||||
renamed_bullets = [f'* {k!r} has been renamed to {v!r}' for k, v in renamings.items()]
|
||||
removed_bullets = [f'* {k!r} has been removed' for k in sorted(deprecated_removed_keys)]
|
||||
message = '\n'.join(['Valid config keys have changed in V2:'] + renamed_bullets + removed_bullets)
|
||||
warnings.warn(message, UserWarning)
|
||||
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import typing
|
||||
from typing import Any, cast
|
||||
|
||||
import typing_extensions
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
from ._schema_generation_shared import (
|
||||
CoreSchemaOrField,
|
||||
GetJsonSchemaFunction,
|
||||
)
|
||||
|
||||
|
||||
class CoreMetadata(typing_extensions.TypedDict, total=False):
|
||||
"""A `TypedDict` for holding the metadata dict of the schema.
|
||||
|
||||
Attributes:
|
||||
pydantic_js_functions: List of JSON schema functions.
|
||||
pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
|
||||
prefer positional over keyword arguments for an 'arguments' schema.
|
||||
"""
|
||||
|
||||
pydantic_js_functions: list[GetJsonSchemaFunction]
|
||||
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
|
||||
|
||||
# If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will
|
||||
# prefer positional over keyword arguments for an 'arguments' schema.
|
||||
pydantic_js_prefer_positional_arguments: bool | None
|
||||
pydantic_js_input_core_schema: CoreSchema | None
|
||||
|
||||
|
||||
class CoreMetadataHandler:
|
||||
"""Because the metadata field in pydantic_core is of type `Dict[str, Any]`, we can't assume much about its contents.
|
||||
|
||||
This class is used to interact with the metadata field on a CoreSchema object in a consistent way throughout pydantic.
|
||||
|
||||
TODO: We'd like to refactor the storage of json related metadata to be more explicit, and less functionally oriented.
|
||||
This should make its way into our v2.10 release. It's inevitable that we need to store some json schema related information
|
||||
on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related
|
||||
issues is quite difficult when JSON schema information is disguised via dynamically defined functions.
|
||||
"""
|
||||
|
||||
__slots__ = ('_schema',)
|
||||
|
||||
def __init__(self, schema: CoreSchemaOrField):
|
||||
self._schema = schema
|
||||
|
||||
metadata = schema.get('metadata')
|
||||
if metadata is None:
|
||||
schema['metadata'] = CoreMetadata() # type: ignore
|
||||
elif not isinstance(metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
|
||||
|
||||
@property
|
||||
def metadata(self) -> CoreMetadata:
|
||||
"""Retrieves the metadata dict from the schema, initializing it to a dict if it is None
|
||||
and raises an error if it is not a dict.
|
||||
"""
|
||||
metadata = self._schema.get('metadata')
|
||||
if metadata is None:
|
||||
self._schema['metadata'] = metadata = CoreMetadata() # type: ignore
|
||||
if not isinstance(metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
|
||||
return cast(CoreMetadata, metadata)
|
||||
|
||||
|
||||
def build_metadata_dict(
|
||||
*, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way
|
||||
js_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
js_prefer_positional_arguments: bool | None = None,
|
||||
js_input_core_schema: CoreSchema | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent with the `CoreMetadataHandler` class."""
|
||||
metadata = CoreMetadata(
|
||||
pydantic_js_functions=js_functions or [],
|
||||
pydantic_js_annotation_functions=js_annotation_functions or [],
|
||||
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
|
||||
pydantic_js_input_core_schema=js_input_core_schema,
|
||||
)
|
||||
return {k: v for k, v in metadata.items() if v is not None}
|
||||
@@ -0,0 +1,570 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Hashable,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic_core import validate_core_schema as _validate_core_schema
|
||||
from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin
|
||||
|
||||
from . import _repr
|
||||
from ._typing_extra import is_generic_alias
|
||||
|
||||
AnyFunctionSchema = Union[
|
||||
core_schema.AfterValidatorFunctionSchema,
|
||||
core_schema.BeforeValidatorFunctionSchema,
|
||||
core_schema.WrapValidatorFunctionSchema,
|
||||
core_schema.PlainValidatorFunctionSchema,
|
||||
]
|
||||
|
||||
|
||||
FunctionSchemaWithInnerSchema = Union[
|
||||
core_schema.AfterValidatorFunctionSchema,
|
||||
core_schema.BeforeValidatorFunctionSchema,
|
||||
core_schema.WrapValidatorFunctionSchema,
|
||||
]
|
||||
|
||||
CoreSchemaField = Union[
|
||||
core_schema.ModelField, core_schema.DataclassField, core_schema.TypedDictField, core_schema.ComputedField
|
||||
]
|
||||
CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
|
||||
|
||||
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
|
||||
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
|
||||
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
|
||||
|
||||
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
|
||||
"""
|
||||
Used in a `Tag` schema to specify the tag used for a discriminated union.
|
||||
"""
|
||||
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
|
||||
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
|
||||
schema was first encountered.
|
||||
"""
|
||||
|
||||
|
||||
def is_core_schema(
|
||||
schema: CoreSchemaOrField,
|
||||
) -> TypeGuard[CoreSchema]:
|
||||
return schema['type'] not in _CORE_SCHEMA_FIELD_TYPES
|
||||
|
||||
|
||||
def is_core_schema_field(
|
||||
schema: CoreSchemaOrField,
|
||||
) -> TypeGuard[CoreSchemaField]:
|
||||
return schema['type'] in _CORE_SCHEMA_FIELD_TYPES
|
||||
|
||||
|
||||
def is_function_with_inner_schema(
|
||||
schema: CoreSchemaOrField,
|
||||
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
|
||||
return schema['type'] in _FUNCTION_WITH_INNER_SCHEMA_TYPES
|
||||
|
||||
|
||||
def is_list_like_schema_with_items_schema(
|
||||
schema: CoreSchema,
|
||||
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
|
||||
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
|
||||
|
||||
|
||||
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
|
||||
"""Produces the ref to be used for this type by pydantic_core's core schemas.
|
||||
|
||||
This `args_override` argument was added for the purpose of creating valid recursive references
|
||||
when creating generic models without needing to create a concrete class.
|
||||
"""
|
||||
origin = get_origin(type_) or type_
|
||||
|
||||
args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
|
||||
generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
|
||||
if generic_metadata:
|
||||
origin = generic_metadata['origin'] or origin
|
||||
args = generic_metadata['args'] or args
|
||||
|
||||
module_name = getattr(origin, '__module__', '<No __module__>')
|
||||
if isinstance(origin, TypeAliasType):
|
||||
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
|
||||
else:
|
||||
try:
|
||||
qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
|
||||
except Exception:
|
||||
qualname = getattr(origin, '__qualname__', '<No __qualname__>')
|
||||
type_ref = f'{module_name}.{qualname}:{id(origin)}'
|
||||
|
||||
arg_refs: list[str] = []
|
||||
for arg in args:
|
||||
if isinstance(arg, str):
|
||||
# Handle string literals as a special case; we may be able to remove this special handling if we
|
||||
# wrap them in a ForwardRef at some point.
|
||||
arg_ref = f'{arg}:str-{id(arg)}'
|
||||
else:
|
||||
arg_ref = f'{_repr.display_as_type(arg)}:{id(arg)}'
|
||||
arg_refs.append(arg_ref)
|
||||
if arg_refs:
|
||||
type_ref = f'{type_ref}[{",".join(arg_refs)}]'
|
||||
return type_ref
|
||||
|
||||
|
||||
def get_ref(s: core_schema.CoreSchema) -> None | str:
|
||||
"""Get the ref from the schema if it has one.
|
||||
This exists just for type checking to work correctly.
|
||||
"""
|
||||
return s.get('ref', None)
|
||||
|
||||
|
||||
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
|
||||
defs: dict[str, CoreSchema] = {}
|
||||
|
||||
def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
ref = get_ref(s)
|
||||
if ref:
|
||||
defs[ref] = s
|
||||
return recurse(s, _record_valid_refs)
|
||||
|
||||
walk_core_schema(schema, _record_valid_refs)
|
||||
|
||||
return defs
|
||||
|
||||
|
||||
def define_expected_missing_refs(
|
||||
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
|
||||
) -> core_schema.CoreSchema | None:
|
||||
if not allowed_missing_refs:
|
||||
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
|
||||
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
|
||||
return None
|
||||
|
||||
refs = collect_definitions(schema).keys()
|
||||
|
||||
expected_missing_refs = allowed_missing_refs.difference(refs)
|
||||
if expected_missing_refs:
|
||||
definitions: list[core_schema.CoreSchema] = [
|
||||
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/619
|
||||
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
|
||||
for ref in expected_missing_refs
|
||||
]
|
||||
return core_schema.definitions_schema(schema, definitions)
|
||||
return None
|
||||
|
||||
|
||||
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
|
||||
invalid = False
|
||||
|
||||
def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
nonlocal invalid
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
|
||||
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
|
||||
return s
|
||||
return recurse(s, _is_schema_valid)
|
||||
|
||||
walk_core_schema(schema, _is_schema_valid)
|
||||
return invalid
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
|
||||
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]
|
||||
|
||||
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/615
|
||||
|
||||
|
||||
class _WalkCoreSchema:
|
||||
def __init__(self):
|
||||
self._schema_type_to_method = self._build_schema_type_to_method()
|
||||
|
||||
def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
|
||||
mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
|
||||
key: core_schema.CoreSchemaType
|
||||
for key in get_args(core_schema.CoreSchemaType):
|
||||
method_name = f"handle_{key.replace('-', '_')}_schema"
|
||||
mapping[key] = getattr(self, method_name, self._handle_other_schemas)
|
||||
return mapping
|
||||
|
||||
def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
return f(schema, self._walk)
|
||||
|
||||
def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
|
||||
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
|
||||
if ser_schema:
|
||||
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
|
||||
return schema
|
||||
|
||||
def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
sub_schema = schema.get('schema', None)
|
||||
if sub_schema is not None:
|
||||
schema['schema'] = self.walk(sub_schema, f) # type: ignore
|
||||
return schema
|
||||
|
||||
def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
|
||||
schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
|
||||
return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
|
||||
if schema is not None or return_schema is not None:
|
||||
ser_schema = ser_schema.copy()
|
||||
if schema is not None:
|
||||
ser_schema['schema'] = self.walk(schema, f) # type: ignore
|
||||
if return_schema is not None:
|
||||
ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore
|
||||
return ser_schema
|
||||
|
||||
def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_definitions: list[core_schema.CoreSchema] = []
|
||||
for definition in schema['definitions']:
|
||||
if 'schema_ref' in definition and 'ref' in definition:
|
||||
# This indicates a purposely indirect reference
|
||||
# We want to keep such references around for implications related to JSON schema, etc.:
|
||||
new_definitions.append(definition)
|
||||
# However, we still need to walk the referenced definition:
|
||||
self.walk(definition, f)
|
||||
continue
|
||||
|
||||
updated_definition = self.walk(definition, f)
|
||||
if 'ref' in updated_definition:
|
||||
# If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
|
||||
# This is most likely to happen due to replacing something with a definition reference, in
|
||||
# which case it should certainly not go in the definitions list
|
||||
new_definitions.append(updated_definition)
|
||||
new_inner_schema = self.walk(schema['schema'], f)
|
||||
|
||||
if not new_definitions and len(schema) == 3:
|
||||
# This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
|
||||
return new_inner_schema
|
||||
|
||||
new_schema = schema.copy()
|
||||
new_schema['schema'] = new_inner_schema
|
||||
new_schema['definitions'] = new_definitions
|
||||
return new_schema
|
||||
|
||||
def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_tuple_schema(self, schema: core_schema.TupleSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
|
||||
return schema
|
||||
|
||||
def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
keys_schema = schema.get('keys_schema')
|
||||
if keys_schema is not None:
|
||||
schema['keys_schema'] = self.walk(keys_schema, f)
|
||||
values_schema = schema.get('values_schema')
|
||||
if values_schema:
|
||||
schema['values_schema'] = self.walk(values_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
if not is_function_with_inner_schema(schema):
|
||||
return schema
|
||||
schema['schema'] = self.walk(schema['schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
|
||||
for v in schema['choices']:
|
||||
if isinstance(v, tuple):
|
||||
new_choices.append((self.walk(v[0], f), v[1]))
|
||||
else:
|
||||
new_choices.append(self.walk(v, f))
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
|
||||
def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: dict[Hashable, core_schema.CoreSchema] = {}
|
||||
for k, v in schema['choices'].items():
|
||||
new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
|
||||
def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['steps'] = [self.walk(v, f) for v in schema['steps']]
|
||||
return schema
|
||||
|
||||
def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['lax_schema'] = self.walk(schema['lax_schema'], f)
|
||||
schema['strict_schema'] = self.walk(schema['strict_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['json_schema'] = self.walk(schema['json_schema'], f)
|
||||
schema['python_schema'] = self.walk(schema['python_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_fields: dict[str, core_schema.ModelField] = {}
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
replaced_fields: dict[str, core_schema.TypedDictField] = {}
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_fields: list[core_schema.DataclassField] = []
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for field in schema['fields']:
|
||||
replaced_field = field.copy()
|
||||
replaced_field['schema'] = self.walk(field['schema'], f)
|
||||
replaced_fields.append(replaced_field)
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
|
||||
for param in schema['arguments_schema']:
|
||||
replaced_param = param.copy()
|
||||
replaced_param['schema'] = self.walk(param['schema'], f)
|
||||
replaced_arguments_schema.append(replaced_param)
|
||||
schema['arguments_schema'] = replaced_arguments_schema
|
||||
if 'var_args_schema' in schema:
|
||||
schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
|
||||
if 'var_kwargs_schema' in schema:
|
||||
schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
|
||||
if 'return_schema' in schema:
|
||||
schema['return_schema'] = self.walk(schema['return_schema'], f)
|
||||
return schema
|
||||
|
||||
|
||||
_dispatch = _WalkCoreSchema().walk
|
||||
|
||||
|
||||
def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
"""Recursively traverse a CoreSchema.
|
||||
|
||||
Args:
|
||||
schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
|
||||
f (Walk): A function to apply. This function takes two arguments:
|
||||
1. The current CoreSchema that is being processed
|
||||
(not the same one you passed into this function, one level down).
|
||||
2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
|
||||
to pass data down the recursive calls without using globals or other mutable state.
|
||||
|
||||
Returns:
|
||||
core_schema.CoreSchema: A processed CoreSchema.
|
||||
"""
|
||||
return f(schema.copy(), _dispatch)
|
||||
|
||||
|
||||
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
|
||||
definitions: dict[str, core_schema.CoreSchema] = {}
|
||||
ref_counts: dict[str, int] = defaultdict(int)
|
||||
involved_in_recursion: dict[str, bool] = {}
|
||||
current_recursion_ref_count: dict[str, int] = defaultdict(int)
|
||||
|
||||
def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definitions':
|
||||
for definition in s['definitions']:
|
||||
ref = get_ref(definition)
|
||||
assert ref is not None
|
||||
if ref not in definitions:
|
||||
definitions[ref] = definition
|
||||
recurse(definition, collect_refs)
|
||||
return recurse(s['schema'], collect_refs)
|
||||
else:
|
||||
ref = get_ref(s)
|
||||
if ref is not None:
|
||||
new = recurse(s, collect_refs)
|
||||
new_ref = get_ref(new)
|
||||
if new_ref:
|
||||
definitions[new_ref] = new
|
||||
return core_schema.definition_reference_schema(schema_ref=ref)
|
||||
else:
|
||||
return recurse(s, collect_refs)
|
||||
|
||||
schema = walk_core_schema(schema, collect_refs)
|
||||
|
||||
def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] != 'definition-ref':
|
||||
return recurse(s, count_refs)
|
||||
ref = s['schema_ref']
|
||||
ref_counts[ref] += 1
|
||||
|
||||
if ref_counts[ref] >= 2:
|
||||
# If this model is involved in a recursion this should be detected
|
||||
# on its second encounter, we can safely stop the walk here.
|
||||
if current_recursion_ref_count[ref] != 0:
|
||||
involved_in_recursion[ref] = True
|
||||
return s
|
||||
|
||||
current_recursion_ref_count[ref] += 1
|
||||
recurse(definitions[ref], count_refs)
|
||||
current_recursion_ref_count[ref] -= 1
|
||||
return s
|
||||
|
||||
schema = walk_core_schema(schema, count_refs)
|
||||
|
||||
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
|
||||
|
||||
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
|
||||
if ref_counts[ref] > 1:
|
||||
return False
|
||||
if involved_in_recursion.get(ref, False):
|
||||
return False
|
||||
if 'serialization' in s:
|
||||
return False
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
for k in (
|
||||
'pydantic_js_functions',
|
||||
'pydantic_js_annotation_functions',
|
||||
'pydantic.internal.union_discriminator',
|
||||
):
|
||||
if k in metadata:
|
||||
# we need to keep this as a ref
|
||||
return False
|
||||
return True
|
||||
|
||||
def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definition-ref':
|
||||
ref = s['schema_ref']
|
||||
# Check if the reference is only used once, not involved in recursion and does not have
|
||||
# any extra keys (like 'serialization')
|
||||
if can_be_inlined(s, ref):
|
||||
# Inline the reference by replacing the reference with the actual schema
|
||||
new = definitions.pop(ref)
|
||||
ref_counts[ref] -= 1 # because we just replaced it!
|
||||
# put all other keys that were on the def-ref schema into the inlined version
|
||||
# in particular this is needed for `serialization`
|
||||
if 'serialization' in s:
|
||||
new['serialization'] = s['serialization']
|
||||
s = recurse(new, inline_refs)
|
||||
return s
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
|
||||
schema = walk_core_schema(schema, inline_refs)
|
||||
|
||||
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore
|
||||
|
||||
if def_values:
|
||||
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
|
||||
return schema
|
||||
|
||||
|
||||
def _strip_metadata(schema: CoreSchema) -> CoreSchema:
|
||||
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
|
||||
s = s.copy()
|
||||
s.pop('metadata', None)
|
||||
if s['type'] == 'model-fields':
|
||||
s = s.copy()
|
||||
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
|
||||
for field_name, field_schema in s['fields'].items():
|
||||
field_schema.pop('metadata', None)
|
||||
s['fields'][field_name] = field_schema
|
||||
computed_fields = s.get('computed_fields', None)
|
||||
if computed_fields:
|
||||
s['computed_fields'] = [cf.copy() for cf in computed_fields]
|
||||
for cf in computed_fields:
|
||||
cf.pop('metadata', None)
|
||||
else:
|
||||
s.pop('computed_fields', None)
|
||||
elif s['type'] == 'model':
|
||||
# remove some defaults
|
||||
if s.get('custom_init', True) is False:
|
||||
s.pop('custom_init')
|
||||
if s.get('root_model', True) is False:
|
||||
s.pop('root_model')
|
||||
if {'title'}.issuperset(s.get('config', {}).keys()):
|
||||
s.pop('config', None)
|
||||
|
||||
return recurse(s, strip_metadata)
|
||||
|
||||
return walk_core_schema(schema, strip_metadata)
|
||||
|
||||
|
||||
def pretty_print_core_schema(
|
||||
schema: CoreSchema,
|
||||
include_metadata: bool = False,
|
||||
) -> None:
|
||||
"""Pretty print a CoreSchema using rich.
|
||||
This is intended for debugging purposes.
|
||||
|
||||
Args:
|
||||
schema: The CoreSchema to print.
|
||||
include_metadata: Whether to include metadata in the output. Defaults to `False`.
|
||||
"""
|
||||
from rich import print # type: ignore # install it manually in your dev env
|
||||
|
||||
if not include_metadata:
|
||||
schema = _strip_metadata(schema)
|
||||
|
||||
return print(schema)
|
||||
|
||||
|
||||
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
|
||||
if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
|
||||
return schema
|
||||
return _validate_core_schema(schema)
|
||||
@@ -0,0 +1,232 @@
|
||||
"""Private logic for creating pydantic dataclasses."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import typing
|
||||
import warnings
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, ClassVar
|
||||
|
||||
from pydantic_core import (
|
||||
ArgsKwargs,
|
||||
SchemaSerializer,
|
||||
SchemaValidator,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from ..errors import PydanticUndefinedAnnotation
|
||||
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
from . import _config, _decorators, _typing_extra
|
||||
from ._fields import collect_dataclass_fields
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._generics import get_standard_typevars_map
|
||||
from ._mock_val_ser import set_dataclass_mocks
|
||||
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
|
||||
from ._signature import generate_pydantic_signature
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..config import ConfigDict
|
||||
from ..fields import FieldInfo
|
||||
|
||||
class StandardDataclass(typing.Protocol):
|
||||
__dataclass_fields__: ClassVar[dict[str, Any]]
|
||||
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
|
||||
__post_init__: ClassVar[Callable[..., None]]
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
class PydanticDataclass(StandardDataclass, typing.Protocol):
|
||||
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
|
||||
|
||||
Attributes:
|
||||
__pydantic_config__: Pydantic-specific configuration settings for the dataclass.
|
||||
__pydantic_complete__: Whether dataclass building is completed, or if there are still undefined fields.
|
||||
__pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
|
||||
__pydantic_decorators__: Metadata containing the decorators defined on the dataclass.
|
||||
__pydantic_fields__: Metadata about the fields defined on the dataclass.
|
||||
__pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the dataclass.
|
||||
__pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the dataclass.
|
||||
"""
|
||||
|
||||
__pydantic_config__: ClassVar[ConfigDict]
|
||||
__pydantic_complete__: ClassVar[bool]
|
||||
__pydantic_core_schema__: ClassVar[core_schema.CoreSchema]
|
||||
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
|
||||
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
|
||||
__pydantic_serializer__: ClassVar[SchemaSerializer]
|
||||
__pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
|
||||
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
|
||||
def set_dataclass_fields(
|
||||
cls: type[StandardDataclass],
|
||||
types_namespace: dict[str, Any] | None = None,
|
||||
config_wrapper: _config.ConfigWrapper | None = None,
|
||||
) -> None:
|
||||
"""Collect and set `cls.__pydantic_fields__`.
|
||||
|
||||
Args:
|
||||
cls: The class.
|
||||
types_namespace: The types namespace, defaults to `None`.
|
||||
config_wrapper: The config wrapper instance, defaults to `None`.
|
||||
"""
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map, config_wrapper=config_wrapper)
|
||||
|
||||
cls.__pydantic_fields__ = fields # type: ignore
|
||||
|
||||
|
||||
def complete_dataclass(
|
||||
cls: type[Any],
|
||||
config_wrapper: _config.ConfigWrapper,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
) -> bool:
|
||||
"""Finish building a pydantic dataclass.
|
||||
|
||||
This logic is called on a class which has already been wrapped in `dataclasses.dataclass()`.
|
||||
|
||||
This is somewhat analogous to `pydantic._internal._model_construction.complete_model_class`.
|
||||
|
||||
Args:
|
||||
cls: The class.
|
||||
config_wrapper: The config wrapper instance.
|
||||
raise_errors: Whether to raise errors, defaults to `True`.
|
||||
types_namespace: The types namespace.
|
||||
|
||||
Returns:
|
||||
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
|
||||
|
||||
Raises:
|
||||
PydanticUndefinedAnnotation: If `raise_error` is `True` and there is an undefined annotations.
|
||||
"""
|
||||
if hasattr(cls, '__post_init_post_parse__'):
|
||||
warnings.warn(
|
||||
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
|
||||
)
|
||||
|
||||
if types_namespace is None:
|
||||
types_namespace = _typing_extra.merge_cls_and_parent_ns(cls)
|
||||
|
||||
set_dataclass_fields(cls, types_namespace, config_wrapper=config_wrapper)
|
||||
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
types_namespace,
|
||||
typevars_map,
|
||||
)
|
||||
|
||||
# This needs to be called before we change the __init__
|
||||
sig = generate_pydantic_signature(
|
||||
init=cls.__init__,
|
||||
fields=cls.__pydantic_fields__, # type: ignore
|
||||
config_wrapper=config_wrapper,
|
||||
is_dataclass=True,
|
||||
)
|
||||
|
||||
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
|
||||
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
|
||||
__tracebackhide__ = True
|
||||
s = __dataclass_self__
|
||||
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
|
||||
|
||||
__init__.__qualname__ = f'{cls.__qualname__}.__init__'
|
||||
|
||||
cls.__init__ = __init__ # type: ignore
|
||||
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
|
||||
cls.__signature__ = sig # type: ignore
|
||||
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
|
||||
try:
|
||||
if get_core_schema:
|
||||
schema = get_core_schema(
|
||||
cls,
|
||||
CallbackGetCoreSchemaHandler(
|
||||
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
|
||||
gen_schema,
|
||||
ref_mode='unpack',
|
||||
),
|
||||
)
|
||||
else:
|
||||
schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
|
||||
except PydanticUndefinedAnnotation as e:
|
||||
if raise_errors:
|
||||
raise
|
||||
set_dataclass_mocks(cls, cls.__name__, f'`{e.name}`')
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(cls)
|
||||
|
||||
try:
|
||||
schema = gen_schema.clean_schema(schema)
|
||||
except gen_schema.CollectedInvalid:
|
||||
set_dataclass_mocks(cls, cls.__name__, 'all referenced types')
|
||||
return False
|
||||
|
||||
# We are about to set all the remaining required properties expected for this cast;
|
||||
# __pydantic_decorators__ and __pydantic_fields__ should already be set
|
||||
cls = typing.cast('type[PydanticDataclass]', cls)
|
||||
# debug(schema)
|
||||
|
||||
cls.__pydantic_core_schema__ = schema
|
||||
cls.__pydantic_validator__ = validator = create_schema_validator(
|
||||
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
|
||||
)
|
||||
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
|
||||
|
||||
if config_wrapper.validate_assignment:
|
||||
|
||||
@wraps(cls.__setattr__)
|
||||
def validated_setattr(instance: Any, field: str, value: str, /) -> None:
|
||||
validator.validate_assignment(instance, field, value)
|
||||
|
||||
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
|
||||
|
||||
cls.__pydantic_complete__ = True
|
||||
return True
|
||||
|
||||
|
||||
def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
|
||||
|
||||
We check that
|
||||
- `_cls` is a dataclass
|
||||
- `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
|
||||
- `_cls` does not have any annotations that are not dataclass fields
|
||||
e.g.
|
||||
```py
|
||||
import dataclasses
|
||||
|
||||
import pydantic.dataclasses
|
||||
|
||||
@dataclasses.dataclass
|
||||
class A:
|
||||
x: int
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class B(A):
|
||||
y: int
|
||||
```
|
||||
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
|
||||
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
|
||||
|
||||
Args:
|
||||
cls: The class.
|
||||
|
||||
Returns:
|
||||
`True` if the class is a stdlib dataclass, `False` otherwise.
|
||||
"""
|
||||
return (
|
||||
dataclasses.is_dataclass(_cls)
|
||||
and not hasattr(_cls, '__pydantic_validator__')
|
||||
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
|
||||
)
|
||||
@@ -0,0 +1,827 @@
|
||||
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, partial, partialmethod
|
||||
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union
|
||||
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from typing_extensions import Literal, TypeAlias, is_typeddict
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from ._core_utils import get_type_ref
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._typing_extra import get_function_type_hints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..fields import ComputedFieldInfo
|
||||
from ..functional_validators import FieldValidatorModes
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class ValidatorDecoratorInfo:
|
||||
"""A container for data from `@validator` so that we can access it
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@validator'.
|
||||
fields: A tuple of field names the validator should be called on.
|
||||
mode: The proposed validator mode.
|
||||
each_item: For complex objects (sets, lists etc.) whether to validate individual
|
||||
elements rather than the whole object.
|
||||
always: Whether this method and other validators should be called even if the value is missing.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@validator'
|
||||
|
||||
fields: tuple[str, ...]
|
||||
mode: Literal['before', 'after']
|
||||
each_item: bool
|
||||
always: bool
|
||||
check_fields: bool | None
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class FieldValidatorDecoratorInfo:
|
||||
"""A container for data from `@field_validator` so that we can access it
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@field_validator'.
|
||||
fields: A tuple of field names the validator should be called on.
|
||||
mode: The proposed validator mode.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate
|
||||
the appropriate JSON Schema (in validation mode) and can only specified
|
||||
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@field_validator'
|
||||
|
||||
fields: tuple[str, ...]
|
||||
mode: FieldValidatorModes
|
||||
check_fields: bool | None
|
||||
json_schema_input_type: Any
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class RootValidatorDecoratorInfo:
|
||||
"""A container for data from `@root_validator` so that we can access it
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@root_validator'.
|
||||
mode: The proposed validator mode.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@root_validator'
|
||||
mode: Literal['before', 'after']
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class FieldSerializerDecoratorInfo:
|
||||
"""A container for data from `@field_serializer` so that we can access it
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@field_serializer'.
|
||||
fields: A tuple of field names the serializer should be called on.
|
||||
mode: The proposed serializer mode.
|
||||
return_type: The type of the serializer's return value.
|
||||
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
|
||||
and `'json-unless-none'`.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@field_serializer'
|
||||
fields: tuple[str, ...]
|
||||
mode: Literal['plain', 'wrap']
|
||||
return_type: Any
|
||||
when_used: core_schema.WhenUsed
|
||||
check_fields: bool | None
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class ModelSerializerDecoratorInfo:
|
||||
"""A container for data from `@model_serializer` so that we can access it
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
|
||||
mode: The proposed serializer mode.
|
||||
return_type: The type of the serializer's return value.
|
||||
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
|
||||
and `'json-unless-none'`.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@model_serializer'
|
||||
mode: Literal['plain', 'wrap']
|
||||
return_type: Any
|
||||
when_used: core_schema.WhenUsed
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class ModelValidatorDecoratorInfo:
|
||||
"""A container for data from `@model_validator` so that we can access it
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@model_validator'.
|
||||
mode: The proposed serializer mode.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@model_validator'
|
||||
mode: Literal['wrap', 'before', 'after']
|
||||
|
||||
|
||||
DecoratorInfo: TypeAlias = """Union[
|
||||
ValidatorDecoratorInfo,
|
||||
FieldValidatorDecoratorInfo,
|
||||
RootValidatorDecoratorInfo,
|
||||
FieldSerializerDecoratorInfo,
|
||||
ModelSerializerDecoratorInfo,
|
||||
ModelValidatorDecoratorInfo,
|
||||
ComputedFieldInfo,
|
||||
]"""
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
DecoratedType: TypeAlias = (
|
||||
'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
|
||||
)
|
||||
|
||||
|
||||
@dataclass # can't use slots here since we set attributes on `__post_init__`
|
||||
class PydanticDescriptorProxy(Generic[ReturnType]):
|
||||
"""Wrap a classmethod, staticmethod, property or unbound function
|
||||
and act as a descriptor that allows us to detect decorated items
|
||||
from the class' attributes.
|
||||
|
||||
This class' __get__ returns the wrapped item's __get__ result,
|
||||
which makes it transparent for classmethods and staticmethods.
|
||||
|
||||
Attributes:
|
||||
wrapped: The decorator that has to be wrapped.
|
||||
decorator_info: The decorator info.
|
||||
shim: A wrapper function to wrap V1 style function.
|
||||
"""
|
||||
|
||||
wrapped: DecoratedType[ReturnType]
|
||||
decorator_info: DecoratorInfo
|
||||
shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
for attr in 'setter', 'deleter':
|
||||
if hasattr(self.wrapped, attr):
|
||||
f = partial(self._call_wrapped_attr, name=attr)
|
||||
setattr(self, attr, f)
|
||||
|
||||
def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
|
||||
self.wrapped = getattr(self.wrapped, name)(func)
|
||||
if isinstance(self.wrapped, property):
|
||||
# update ComputedFieldInfo.wrapped_property
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
if isinstance(self.decorator_info, ComputedFieldInfo):
|
||||
self.decorator_info.wrapped_property = self.wrapped
|
||||
return self
|
||||
|
||||
def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
|
||||
try:
|
||||
return self.wrapped.__get__(obj, obj_type)
|
||||
except AttributeError:
|
||||
# not a descriptor, e.g. a partial object
|
||||
return self.wrapped # type: ignore[return-value]
|
||||
|
||||
def __set_name__(self, instance: Any, name: str) -> None:
|
||||
if hasattr(self.wrapped, '__set_name__'):
|
||||
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]
|
||||
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
"""Forward checks for __isabstractmethod__ and such."""
|
||||
return getattr(self.wrapped, __name)
|
||||
|
||||
|
||||
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class Decorator(Generic[DecoratorInfoType]):
|
||||
"""A generic container class to join together the decorator metadata
|
||||
(metadata from decorator itself, which we have when the
|
||||
decorator is called but not when we are building the core-schema)
|
||||
and the bound function (which we have after the class itself is created).
|
||||
|
||||
Attributes:
|
||||
cls_ref: The class ref.
|
||||
cls_var_name: The decorated function name.
|
||||
func: The decorated function.
|
||||
shim: A wrapper function to wrap V1 style function.
|
||||
info: The decorator info.
|
||||
"""
|
||||
|
||||
cls_ref: str
|
||||
cls_var_name: str
|
||||
func: Callable[..., Any]
|
||||
shim: Callable[[Any], Any] | None
|
||||
info: DecoratorInfoType
|
||||
|
||||
@staticmethod
|
||||
def build(
|
||||
cls_: Any,
|
||||
*,
|
||||
cls_var_name: str,
|
||||
shim: Callable[[Any], Any] | None,
|
||||
info: DecoratorInfoType,
|
||||
) -> Decorator[DecoratorInfoType]:
|
||||
"""Build a new decorator.
|
||||
|
||||
Args:
|
||||
cls_: The class.
|
||||
cls_var_name: The decorated function name.
|
||||
shim: A wrapper function to wrap V1 style function.
|
||||
info: The decorator info.
|
||||
|
||||
Returns:
|
||||
The new decorator instance.
|
||||
"""
|
||||
func = get_attribute_from_bases(cls_, cls_var_name)
|
||||
if shim is not None:
|
||||
func = shim(func)
|
||||
func = unwrap_wrapped_function(func, unwrap_partial=False)
|
||||
if not callable(func):
|
||||
# This branch will get hit for classmethod properties
|
||||
attribute = get_attribute_from_base_dicts(cls_, cls_var_name) # prevents the binding call to `__get__`
|
||||
if isinstance(attribute, PydanticDescriptorProxy):
|
||||
func = unwrap_wrapped_function(attribute.wrapped)
|
||||
return Decorator(
|
||||
cls_ref=get_type_ref(cls_),
|
||||
cls_var_name=cls_var_name,
|
||||
func=func,
|
||||
shim=shim,
|
||||
info=info,
|
||||
)
|
||||
|
||||
def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
|
||||
"""Bind the decorator to a class.
|
||||
|
||||
Args:
|
||||
cls: the class.
|
||||
|
||||
Returns:
|
||||
The new decorator instance.
|
||||
"""
|
||||
return self.build(
|
||||
cls,
|
||||
cls_var_name=self.cls_var_name,
|
||||
shim=self.shim,
|
||||
info=self.info,
|
||||
)
|
||||
|
||||
|
||||
def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
|
||||
"""Get the base classes of a class or typeddict.
|
||||
|
||||
Args:
|
||||
tp: The type or class to get the bases.
|
||||
|
||||
Returns:
|
||||
The base classes.
|
||||
"""
|
||||
if is_typeddict(tp):
|
||||
return tp.__orig_bases__ # type: ignore
|
||||
try:
|
||||
return tp.__bases__
|
||||
except AttributeError:
|
||||
return ()
|
||||
|
||||
|
||||
def mro(tp: type[Any]) -> tuple[type[Any], ...]:
|
||||
"""Calculate the Method Resolution Order of bases using the C3 algorithm.
|
||||
|
||||
See https://www.python.org/download/releases/2.3/mro/
|
||||
"""
|
||||
# try to use the existing mro, for performance mainly
|
||||
# but also because it helps verify the implementation below
|
||||
if not is_typeddict(tp):
|
||||
try:
|
||||
return tp.__mro__
|
||||
except AttributeError:
|
||||
# GenericAlias and some other cases
|
||||
pass
|
||||
|
||||
bases = get_bases(tp)
|
||||
return (tp,) + mro_for_bases(bases)
|
||||
|
||||
|
||||
def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
|
||||
def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
|
||||
while True:
|
||||
non_empty = [seq for seq in seqs if seq]
|
||||
if not non_empty:
|
||||
# Nothing left to process, we're done.
|
||||
return
|
||||
candidate: type[Any] | None = None
|
||||
for seq in non_empty: # Find merge candidates among seq heads.
|
||||
candidate = seq[0]
|
||||
not_head = [s for s in non_empty if candidate in islice(s, 1, None)]
|
||||
if not_head:
|
||||
# Reject the candidate.
|
||||
candidate = None
|
||||
else:
|
||||
break
|
||||
if not candidate:
|
||||
raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
|
||||
yield candidate
|
||||
for seq in non_empty:
|
||||
# Remove candidate.
|
||||
if seq[0] == candidate:
|
||||
seq.popleft()
|
||||
|
||||
seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
|
||||
return tuple(merge_seqs(seqs))
|
||||
|
||||
|
||||
_sentinel = object()
|
||||
|
||||
|
||||
def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
|
||||
"""Get the attribute from the next class in the MRO that has it,
|
||||
aiming to simulate calling the method on the actual class.
|
||||
|
||||
The reason for iterating over the mro instead of just getting
|
||||
the attribute (which would do that for us) is to support TypedDict,
|
||||
which lacks a real __mro__, but can have a virtual one constructed
|
||||
from its bases (as done here).
|
||||
|
||||
Args:
|
||||
tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
|
||||
name: The name of the attribute to retrieve.
|
||||
|
||||
Returns:
|
||||
Any: The attribute value, if found.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the attribute is not found in any class in the MRO.
|
||||
"""
|
||||
if isinstance(tp, tuple):
|
||||
for base in mro_for_bases(tp):
|
||||
attribute = base.__dict__.get(name, _sentinel)
|
||||
if attribute is not _sentinel:
|
||||
attribute_get = getattr(attribute, '__get__', None)
|
||||
if attribute_get is not None:
|
||||
return attribute_get(None, tp)
|
||||
return attribute
|
||||
raise AttributeError(f'{name} not found in {tp}')
|
||||
else:
|
||||
try:
|
||||
return getattr(tp, name)
|
||||
except AttributeError:
|
||||
return get_attribute_from_bases(mro(tp), name)
|
||||
|
||||
|
||||
def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
|
||||
"""Get an attribute out of the `__dict__` following the MRO.
|
||||
This prevents the call to `__get__` on the descriptor, and allows
|
||||
us to get the original function for classmethod properties.
|
||||
|
||||
Args:
|
||||
tp: The type or class to search for the attribute.
|
||||
name: The name of the attribute to retrieve.
|
||||
|
||||
Returns:
|
||||
Any: The attribute value, if found.
|
||||
|
||||
Raises:
|
||||
KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
|
||||
"""
|
||||
for base in reversed(mro(tp)):
|
||||
if name in base.__dict__:
|
||||
return base.__dict__[name]
|
||||
return tp.__dict__[name] # raise the error
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class DecoratorInfos:
|
||||
"""Mapping of name in the class namespace to decorator info.
|
||||
|
||||
note that the name in the class namespace is the function or attribute name
|
||||
not the field name!
|
||||
"""
|
||||
|
||||
validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
|
||||
field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
|
||||
root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
|
||||
field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
|
||||
model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
|
||||
model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
|
||||
computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)
|
||||
|
||||
@staticmethod
|
||||
def build(model_dc: type[Any]) -> DecoratorInfos: # noqa: C901 (ignore complexity)
|
||||
"""We want to collect all DecFunc instances that exist as
|
||||
attributes in the namespace of the class (a BaseModel or dataclass)
|
||||
that called us
|
||||
But we want to collect these in the order of the bases
|
||||
So instead of getting them all from the leaf class (the class that called us),
|
||||
we traverse the bases from root (the oldest ancestor class) to leaf
|
||||
and collect all of the instances as we go, taking care to replace
|
||||
any duplicate ones with the last one we see to mimic how function overriding
|
||||
works with inheritance.
|
||||
If we do replace any functions we put the replacement into the position
|
||||
the replaced function was in; that is, we maintain the order.
|
||||
"""
|
||||
# reminder: dicts are ordered and replacement does not alter the order
|
||||
res = DecoratorInfos()
|
||||
for base in reversed(mro(model_dc)[1:]):
|
||||
existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
|
||||
if existing is None:
|
||||
existing = DecoratorInfos.build(base)
|
||||
res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
|
||||
res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
|
||||
res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
|
||||
res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
|
||||
res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
|
||||
res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
|
||||
res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})
|
||||
|
||||
to_replace: list[tuple[str, Any]] = []
|
||||
|
||||
for var_name, var_value in vars(model_dc).items():
|
||||
if isinstance(var_value, PydanticDescriptorProxy):
|
||||
info = var_value.decorator_info
|
||||
if isinstance(info, ValidatorDecoratorInfo):
|
||||
res.validators[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
elif isinstance(info, FieldValidatorDecoratorInfo):
|
||||
res.field_validators[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
elif isinstance(info, RootValidatorDecoratorInfo):
|
||||
res.root_validators[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
elif isinstance(info, FieldSerializerDecoratorInfo):
|
||||
# check whether a serializer function is already registered for fields
|
||||
for field_serializer_decorator in res.field_serializers.values():
|
||||
# check that each field has at most one serializer function.
|
||||
# serializer functions for the same field in subclasses are allowed,
|
||||
# and are treated as overrides
|
||||
if field_serializer_decorator.cls_var_name == var_name:
|
||||
continue
|
||||
for f in info.fields:
|
||||
if f in field_serializer_decorator.info.fields:
|
||||
raise PydanticUserError(
|
||||
'Multiple field serializer functions were defined '
|
||||
f'for field {f!r}, this is not allowed.',
|
||||
code='multiple-field-serializers',
|
||||
)
|
||||
res.field_serializers[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
elif isinstance(info, ModelValidatorDecoratorInfo):
|
||||
res.model_validators[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
elif isinstance(info, ModelSerializerDecoratorInfo):
|
||||
res.model_serializers[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
else:
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
isinstance(var_value, ComputedFieldInfo)
|
||||
res.computed_fields[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=None, info=info
|
||||
)
|
||||
to_replace.append((var_name, var_value.wrapped))
|
||||
if to_replace:
|
||||
# If we can save `__pydantic_decorators__` on the class we'll be able to check for it above
|
||||
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
|
||||
# and replace them with the thing they are wrapping (see the other setattr call below)
|
||||
# which allows validator class methods to also function as regular class methods
|
||||
model_dc.__pydantic_decorators__ = res
|
||||
for name, value in to_replace:
|
||||
setattr(model_dc, name, value)
|
||||
return res
|
||||
|
||||
|
||||
def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
|
||||
"""Look at a field or model validator function and determine whether it takes an info argument.
|
||||
|
||||
An error is raised if the function has an invalid signature.
|
||||
|
||||
Args:
|
||||
validator: The validator function to inspect.
|
||||
mode: The proposed validator mode.
|
||||
|
||||
Returns:
|
||||
Whether the validator takes an info argument.
|
||||
"""
|
||||
try:
|
||||
sig = signature(validator)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present:
|
||||
return False
|
||||
n_positional = count_positional_required_params(sig)
|
||||
if mode == 'wrap':
|
||||
if n_positional == 3:
|
||||
return True
|
||||
elif n_positional == 2:
|
||||
return False
|
||||
else:
|
||||
assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain"
|
||||
if n_positional == 2:
|
||||
return True
|
||||
elif n_positional == 1:
|
||||
return False
|
||||
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
|
||||
code='validator-signature',
|
||||
)
|
||||
|
||||
|
||||
def inspect_field_serializer(
|
||||
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
|
||||
) -> tuple[bool, bool]:
|
||||
"""Look at a field serializer function and determine if it is a field serializer,
|
||||
and whether it takes an info argument.
|
||||
|
||||
An error is raised if the function has an invalid signature.
|
||||
|
||||
Args:
|
||||
serializer: The serializer function to inspect.
|
||||
mode: The serializer mode, either 'plain' or 'wrap'.
|
||||
computed_field: When serializer is applied on computed_field. It doesn't require
|
||||
info signature.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_field_serializer, info_arg).
|
||||
"""
|
||||
try:
|
||||
sig = signature(serializer)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present and this is not a method:
|
||||
return (False, False)
|
||||
|
||||
first = next(iter(sig.parameters.values()), None)
|
||||
is_field_serializer = first is not None and first.name == 'self'
|
||||
|
||||
n_positional = count_positional_required_params(sig)
|
||||
if is_field_serializer:
|
||||
# -1 to correct for self parameter
|
||||
info_arg = _serializer_info_arg(mode, n_positional - 1)
|
||||
else:
|
||||
info_arg = _serializer_info_arg(mode, n_positional)
|
||||
|
||||
if info_arg is None:
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
code='field-serializer-signature',
|
||||
)
|
||||
if info_arg and computed_field:
|
||||
raise PydanticUserError(
|
||||
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
|
||||
)
|
||||
|
||||
else:
|
||||
return is_field_serializer, info_arg
|
||||
|
||||
|
||||
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
|
||||
"""Look at a serializer function used via `Annotated` and determine whether it takes an info argument.
|
||||
|
||||
An error is raised if the function has an invalid signature.
|
||||
|
||||
Args:
|
||||
serializer: The serializer function to check.
|
||||
mode: The serializer mode, either 'plain' or 'wrap'.
|
||||
|
||||
Returns:
|
||||
info_arg
|
||||
"""
|
||||
try:
|
||||
sig = signature(serializer)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present:
|
||||
return False
|
||||
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
|
||||
if info_arg is None:
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
code='field-serializer-signature',
|
||||
)
|
||||
else:
|
||||
return info_arg
|
||||
|
||||
|
||||
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
|
||||
"""Look at a model serializer function and determine whether it takes an info argument.
|
||||
|
||||
An error is raised if the function has an invalid signature.
|
||||
|
||||
Args:
|
||||
serializer: The serializer function to check.
|
||||
mode: The serializer mode, either 'plain' or 'wrap'.
|
||||
|
||||
Returns:
|
||||
`info_arg` - whether the function expects an info argument.
|
||||
"""
|
||||
if isinstance(serializer, (staticmethod, classmethod)) or not is_instance_method_from_sig(serializer):
|
||||
raise PydanticUserError(
|
||||
'`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
|
||||
)
|
||||
|
||||
sig = signature(serializer)
|
||||
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
|
||||
if info_arg is None:
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
code='model-serializer-signature',
|
||||
)
|
||||
else:
|
||||
return info_arg
|
||||
|
||||
|
||||
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
|
||||
if mode == 'plain':
|
||||
if n_positional == 1:
|
||||
# (input_value: Any, /) -> Any
|
||||
return False
|
||||
elif n_positional == 2:
|
||||
# (model: Any, input_value: Any, /) -> Any
|
||||
return True
|
||||
else:
|
||||
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
|
||||
if n_positional == 2:
|
||||
# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
|
||||
return False
|
||||
elif n_positional == 3:
|
||||
# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
|
||||
return True
|
||||
|
||||
return None
|
||||
|
||||
|
||||
AnyDecoratorCallable: TypeAlias = (
|
||||
'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
|
||||
)
|
||||
|
||||
|
||||
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
|
||||
"""Whether the function is an instance method.
|
||||
|
||||
It will consider a function as instance method if the first parameter of
|
||||
function is `self`.
|
||||
|
||||
Args:
|
||||
function: The function to check.
|
||||
|
||||
Returns:
|
||||
`True` if the function is an instance method, `False` otherwise.
|
||||
"""
|
||||
sig = signature(unwrap_wrapped_function(function))
|
||||
first = next(iter(sig.parameters.values()), None)
|
||||
if first and first.name == 'self':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
|
||||
"""Apply the `@classmethod` decorator on the function.
|
||||
|
||||
Args:
|
||||
function: The function to apply the decorator on.
|
||||
|
||||
Return:
|
||||
The `@classmethod` decorator applied function.
|
||||
"""
|
||||
if not isinstance(
|
||||
unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
|
||||
) and _is_classmethod_from_sig(function):
|
||||
return classmethod(function) # type: ignore[arg-type]
|
||||
return function
|
||||
|
||||
|
||||
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
|
||||
sig = signature(unwrap_wrapped_function(function))
|
||||
first = next(iter(sig.parameters.values()), None)
|
||||
if first and first.name == 'cls':
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def unwrap_wrapped_function(
|
||||
func: Any,
|
||||
*,
|
||||
unwrap_partial: bool = True,
|
||||
unwrap_class_static_method: bool = True,
|
||||
) -> Any:
|
||||
"""Recursively unwraps a wrapped function until the underlying function is reached.
|
||||
This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod.
|
||||
|
||||
Args:
|
||||
func: The function to unwrap.
|
||||
unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
|
||||
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
|
||||
decorators. If False, only unwrap partial and partialmethod decorators.
|
||||
|
||||
Returns:
|
||||
The underlying function of the wrapped function.
|
||||
"""
|
||||
# Define the types we want to check against as a single tuple.
|
||||
unwrap_types = (
|
||||
(property, cached_property)
|
||||
+ ((partial, partialmethod) if unwrap_partial else ())
|
||||
+ ((staticmethod, classmethod) if unwrap_class_static_method else ())
|
||||
)
|
||||
|
||||
while isinstance(func, unwrap_types):
|
||||
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
|
||||
func = func.__func__
|
||||
elif isinstance(func, (partial, partialmethod)):
|
||||
func = func.func
|
||||
elif isinstance(func, property):
|
||||
func = func.fget # arbitrary choice, convenient for computed fields
|
||||
else:
|
||||
# Make coverage happy as it can only get here in the last possible case
|
||||
assert isinstance(func, cached_property)
|
||||
func = func.func # type: ignore
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def get_function_return_type(
|
||||
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Get the function return type.
|
||||
|
||||
It gets the return type from the type annotation if `explicit_return_type` is `None`.
|
||||
Otherwise, it returns `explicit_return_type`.
|
||||
|
||||
Args:
|
||||
func: The function to get its return type.
|
||||
explicit_return_type: The explicit return type.
|
||||
types_namespace: The types namespace, defaults to `None`.
|
||||
|
||||
Returns:
|
||||
The function return type.
|
||||
"""
|
||||
if explicit_return_type is PydanticUndefined:
|
||||
# try to get it from the type annotation
|
||||
hints = get_function_type_hints(
|
||||
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
|
||||
)
|
||||
return hints.get('return', PydanticUndefined)
|
||||
else:
|
||||
return explicit_return_type
|
||||
|
||||
|
||||
def count_positional_required_params(sig: Signature) -> int:
|
||||
"""Get the number of positional (required) arguments of a signature.
|
||||
|
||||
This function should only be used to inspect signatures of validation and serialization functions.
|
||||
The first argument (the value being serialized or validated) is counted as a required argument
|
||||
even if a default value exists.
|
||||
|
||||
Returns:
|
||||
The number of positional arguments of a signature.
|
||||
"""
|
||||
parameters = list(sig.parameters.values())
|
||||
return sum(
|
||||
1
|
||||
for param in parameters
|
||||
if can_be_positional(param)
|
||||
# First argument is the value being validated/serialized, and can have a default value
|
||||
# (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg
|
||||
# for instance) should be required, and thus without any default value.
|
||||
and (param.default is Parameter.empty or param == parameters[0])
|
||||
)
|
||||
|
||||
|
||||
def can_be_positional(param: Parameter) -> bool:
|
||||
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
||||
|
||||
|
||||
def ensure_property(f: Any) -> Any:
|
||||
"""Ensure that a function is a `property` or `cached_property`, or is a valid descriptor.
|
||||
|
||||
Args:
|
||||
f: The function to check.
|
||||
|
||||
Returns:
|
||||
The function, or a `property` or `cached_property` instance wrapping the function.
|
||||
"""
|
||||
if ismethoddescriptor(f) or isdatadescriptor(f):
|
||||
return f
|
||||
else:
|
||||
return property(f)
|
||||
@@ -0,0 +1,174 @@
|
||||
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from inspect import Parameter, signature
|
||||
from typing import Any, Dict, Tuple, Union, cast
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from ._decorators import can_be_positional
|
||||
|
||||
|
||||
class V1OnlyValueValidator(Protocol):
|
||||
"""A simple validator, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithValues(Protocol):
|
||||
"""A validator with `values` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithValuesKwOnly(Protocol):
|
||||
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithKwargs(Protocol):
|
||||
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithValuesAndKwargs(Protocol):
|
||||
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
V1Validator = Union[
|
||||
V1ValidatorWithValues, V1ValidatorWithValuesKwOnly, V1ValidatorWithKwargs, V1ValidatorWithValuesAndKwargs
|
||||
]
|
||||
|
||||
|
||||
def can_be_keyword(param: Parameter) -> bool:
|
||||
return param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
|
||||
|
||||
|
||||
def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithInfoValidatorFunction:
|
||||
"""Wrap a V1 style field validator for V2 compatibility.
|
||||
|
||||
Args:
|
||||
validator: The V1 style field validator.
|
||||
|
||||
Returns:
|
||||
A wrapped V2 style field validator.
|
||||
|
||||
Raises:
|
||||
PydanticUserError: If the signature is not supported or the parameters are
|
||||
not available in Pydantic V2.
|
||||
"""
|
||||
sig = signature(validator)
|
||||
|
||||
needs_values_kw = False
|
||||
|
||||
for param_num, (param_name, parameter) in enumerate(sig.parameters.items()):
|
||||
if can_be_keyword(parameter) and param_name in ('field', 'config'):
|
||||
raise PydanticUserError(
|
||||
'The `field` and `config` parameters are not available in Pydantic V2, '
|
||||
'please use the `info` parameter instead.',
|
||||
code='validator-field-config-info',
|
||||
)
|
||||
if parameter.kind is Parameter.VAR_KEYWORD:
|
||||
needs_values_kw = True
|
||||
elif can_be_keyword(parameter) and param_name == 'values':
|
||||
needs_values_kw = True
|
||||
elif can_be_positional(parameter) and param_num == 0:
|
||||
# value
|
||||
continue
|
||||
elif parameter.default is Parameter.empty: # ignore params with defaults e.g. bound by functools.partial
|
||||
raise PydanticUserError(
|
||||
f'Unsupported signature for V1 style validator {validator}: {sig} is not supported.',
|
||||
code='validator-v1-signature',
|
||||
)
|
||||
|
||||
if needs_values_kw:
|
||||
# (v, **kwargs), (v, values, **kwargs), (v, *, values, **kwargs) or (v, *, values)
|
||||
val1 = cast(V1ValidatorWithValues, validator)
|
||||
|
||||
def wrapper1(value: Any, info: core_schema.ValidationInfo) -> Any:
|
||||
return val1(value, values=info.data)
|
||||
|
||||
return wrapper1
|
||||
else:
|
||||
val2 = cast(V1OnlyValueValidator, validator)
|
||||
|
||||
def wrapper2(value: Any, _: core_schema.ValidationInfo) -> Any:
|
||||
return val2(value)
|
||||
|
||||
return wrapper2
|
||||
|
||||
|
||||
RootValidatorValues = Dict[str, Any]
|
||||
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
|
||||
RootValidatorFieldsTuple = Tuple[Any, ...]
|
||||
|
||||
|
||||
class V1RootValidatorFunction(Protocol):
|
||||
"""A simple root validator, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ...
|
||||
|
||||
|
||||
class V2CoreBeforeRootValidator(Protocol):
|
||||
"""V2 validator with mode='before'."""
|
||||
|
||||
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ...
|
||||
|
||||
|
||||
class V2CoreAfterRootValidator(Protocol):
|
||||
"""V2 validator with mode='after'."""
|
||||
|
||||
def __call__(
|
||||
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
|
||||
) -> RootValidatorFieldsTuple: ...
|
||||
|
||||
|
||||
def make_v1_generic_root_validator(
|
||||
validator: V1RootValidatorFunction, pre: bool
|
||||
) -> V2CoreBeforeRootValidator | V2CoreAfterRootValidator:
|
||||
"""Wrap a V1 style root validator for V2 compatibility.
|
||||
|
||||
Args:
|
||||
validator: The V1 style field validator.
|
||||
pre: Whether the validator is a pre validator.
|
||||
|
||||
Returns:
|
||||
A wrapped V2 style validator.
|
||||
"""
|
||||
if pre is True:
|
||||
# mode='before' for pydantic-core
|
||||
def _wrapper1(values: RootValidatorValues, _: core_schema.ValidationInfo) -> RootValidatorValues:
|
||||
return validator(values)
|
||||
|
||||
return _wrapper1
|
||||
|
||||
# mode='after' for pydantic-core
|
||||
def _wrapper2(fields_tuple: RootValidatorFieldsTuple, _: core_schema.ValidationInfo) -> RootValidatorFieldsTuple:
|
||||
if len(fields_tuple) == 2:
|
||||
# dataclass, this is easy
|
||||
values, init_vars = fields_tuple
|
||||
values = validator(values)
|
||||
return values, init_vars
|
||||
else:
|
||||
# ugly hack: to match v1 behaviour, we merge values and model_extra, then split them up based on fields
|
||||
# afterwards
|
||||
model_dict, model_extra, fields_set = fields_tuple
|
||||
if model_extra:
|
||||
fields = set(model_dict.keys())
|
||||
model_dict.update(model_extra)
|
||||
model_dict_new = validator(model_dict)
|
||||
for k in list(model_dict_new.keys()):
|
||||
if k not in fields:
|
||||
model_extra[k] = model_dict_new.pop(k)
|
||||
else:
|
||||
model_dict_new = validator(model_dict)
|
||||
return model_dict_new, model_extra, fields_set
|
||||
|
||||
return _wrapper2
|
||||
@@ -0,0 +1,503 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Hashable, Sequence
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from . import _core_utils
|
||||
from ._core_utils import (
|
||||
CoreSchemaField,
|
||||
collect_definitions,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..types import Discriminator
|
||||
|
||||
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
|
||||
|
||||
|
||||
class MissingDefinitionForUnionRef(Exception):
|
||||
"""Raised when applying a discriminated union discriminator to a schema
|
||||
requires a definition that is not yet defined
|
||||
"""
|
||||
|
||||
def __init__(self, ref: str) -> None:
|
||||
self.ref = ref
|
||||
super().__init__(f'Missing definition for ref {self.ref!r}')
|
||||
|
||||
|
||||
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
|
||||
schema.setdefault('metadata', {})
|
||||
metadata = schema.get('metadata')
|
||||
assert metadata is not None
|
||||
metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
|
||||
|
||||
|
||||
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
# We recursively walk through the `schema` passed to `apply_discriminators`, applying discriminators
|
||||
# where necessary at each level. During this recursion, we allow references to be resolved from the definitions
|
||||
# that are originally present on the original, outermost `schema`. Before `apply_discriminators` is called,
|
||||
# `simplify_schema_references` is called on the schema (in the `clean_schema` function),
|
||||
# which often puts the definitions in the outermost schema.
|
||||
global_definitions: dict[str, CoreSchema] = collect_definitions(schema)
|
||||
|
||||
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
|
||||
nonlocal global_definitions
|
||||
|
||||
s = recurse(s, inner)
|
||||
if s['type'] == 'tagged-union':
|
||||
return s
|
||||
|
||||
metadata = s.get('metadata', {})
|
||||
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
|
||||
if discriminator is not None:
|
||||
s = apply_discriminator(s, discriminator, global_definitions)
|
||||
return s
|
||||
|
||||
return _core_utils.walk_core_schema(schema, inner)
|
||||
|
||||
|
||||
def apply_discriminator(
|
||||
schema: core_schema.CoreSchema,
|
||||
discriminator: str | Discriminator,
|
||||
definitions: dict[str, core_schema.CoreSchema] | None = None,
|
||||
) -> core_schema.CoreSchema:
|
||||
"""Applies the discriminator and returns a new core schema.
|
||||
|
||||
Args:
|
||||
schema: The input schema.
|
||||
discriminator: The name of the field which will serve as the discriminator.
|
||||
definitions: A mapping of schema ref to schema.
|
||||
|
||||
Returns:
|
||||
The new core schema.
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
- If `discriminator` is used with invalid union variant.
|
||||
- If `discriminator` is used with `Union` type with one variant.
|
||||
- If `discriminator` value mapped to multiple choices.
|
||||
MissingDefinitionForUnionRef:
|
||||
If the definition for ref is missing.
|
||||
PydanticUserError:
|
||||
- If a model in union doesn't have a discriminator field.
|
||||
- If discriminator field has a non-string alias.
|
||||
- If discriminator fields have different aliases.
|
||||
- If discriminator field not of type `Literal`.
|
||||
"""
|
||||
from ..types import Discriminator
|
||||
|
||||
if isinstance(discriminator, Discriminator):
|
||||
if isinstance(discriminator.discriminator, str):
|
||||
discriminator = discriminator.discriminator
|
||||
else:
|
||||
return discriminator._convert_schema(schema)
|
||||
|
||||
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
|
||||
|
||||
|
||||
class _ApplyInferredDiscriminator:
|
||||
"""This class is used to convert an input schema containing a union schema into one where that union is
|
||||
replaced with a tagged-union, with all the associated debugging and performance benefits.
|
||||
|
||||
This is done by:
|
||||
* Validating that the input schema is compatible with the provided discriminator
|
||||
* Introspecting the schema to determine which discriminator values should map to which union choices
|
||||
* Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
|
||||
|
||||
I have chosen to implement the conversion algorithm in this class, rather than a function,
|
||||
to make it easier to maintain state while recursively walking the provided CoreSchema.
|
||||
"""
|
||||
|
||||
def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
|
||||
# `discriminator` should be the name of the field which will serve as the discriminator.
|
||||
# It must be the python name of the field, and *not* the field's alias. Note that as of now,
|
||||
# all members of a discriminated union _must_ use a field with the same name as the discriminator.
|
||||
# This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
|
||||
self.discriminator = discriminator
|
||||
|
||||
# `definitions` should contain a mapping of schema ref to schema for all schemas which might
|
||||
# be referenced by some choice
|
||||
self.definitions = definitions
|
||||
|
||||
# `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
|
||||
#
|
||||
# Note: following the v1 implementation, we currently disallow the use of different aliases
|
||||
# for different choices. This is not a limitation of pydantic_core, but if we try to handle
|
||||
# this, the inference logic gets complicated very quickly, and could result in confusing
|
||||
# debugging challenges for users making subtle mistakes.
|
||||
#
|
||||
# Rather than trying to do the most powerful inference possible, I think we should eventually
|
||||
# expose a way to more-manually control the way the TaggedUnionSchema is constructed through
|
||||
# the use of a new type which would be placed as an Annotation on the Union type. This would
|
||||
# provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
|
||||
# more complex cases, without over-complicating the inference logic for the common cases.
|
||||
self._discriminator_alias: str | None = None
|
||||
|
||||
# `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
|
||||
# If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
|
||||
# constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
|
||||
# Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
|
||||
# that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
|
||||
# python side, but resolves the issue that `None` cannot correspond to any discriminator values.
|
||||
self._should_be_nullable = False
|
||||
|
||||
# `_is_nullable` is used to track if the final produced schema will definitely be nullable;
|
||||
# we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
|
||||
# as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
|
||||
# the final value in another nullable schema.
|
||||
#
|
||||
# This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
|
||||
# to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
|
||||
self._is_nullable = False
|
||||
|
||||
# `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
|
||||
# from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
|
||||
# algorithm may add more choices to this stack as (nested) unions are encountered.
|
||||
self._choices_to_handle: list[core_schema.CoreSchema] = []
|
||||
|
||||
# `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
|
||||
# in the output TaggedUnionSchema that will replace the union from the input schema
|
||||
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
|
||||
|
||||
# `_used` is changed to True after applying the discriminator to prevent accidental re-use
|
||||
self._used = False
|
||||
|
||||
def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
"""Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
|
||||
to this class.
|
||||
|
||||
Args:
|
||||
schema: The input schema.
|
||||
|
||||
Returns:
|
||||
The new core schema.
|
||||
|
||||
Raises:
|
||||
TypeError:
|
||||
- If `discriminator` is used with invalid union variant.
|
||||
- If `discriminator` is used with `Union` type with one variant.
|
||||
- If `discriminator` value mapped to multiple choices.
|
||||
ValueError:
|
||||
If the definition for ref is missing.
|
||||
PydanticUserError:
|
||||
- If a model in union doesn't have a discriminator field.
|
||||
- If discriminator field has a non-string alias.
|
||||
- If discriminator fields have different aliases.
|
||||
- If discriminator field not of type `Literal`.
|
||||
"""
|
||||
assert not self._used
|
||||
schema = self._apply_to_root(schema)
|
||||
if self._should_be_nullable and not self._is_nullable:
|
||||
schema = core_schema.nullable_schema(schema)
|
||||
self._used = True
|
||||
return schema
|
||||
|
||||
def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
"""This method handles the outer-most stage of recursion over the input schema:
|
||||
unwrapping nullable or definitions schemas, and calling the `_handle_choice`
|
||||
method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
|
||||
"""
|
||||
if schema['type'] == 'nullable':
|
||||
self._is_nullable = True
|
||||
wrapped = self._apply_to_root(schema['schema'])
|
||||
nullable_wrapper = schema.copy()
|
||||
nullable_wrapper['schema'] = wrapped
|
||||
return nullable_wrapper
|
||||
|
||||
if schema['type'] == 'definitions':
|
||||
wrapped = self._apply_to_root(schema['schema'])
|
||||
definitions_wrapper = schema.copy()
|
||||
definitions_wrapper['schema'] = wrapped
|
||||
return definitions_wrapper
|
||||
|
||||
if schema['type'] != 'union':
|
||||
# If the schema is not a union, it probably means it just had a single member and
|
||||
# was flattened by pydantic_core.
|
||||
# However, it still may make sense to apply the discriminator to this schema,
|
||||
# as a way to get discriminated-union-style error messages, so we allow this here.
|
||||
schema = core_schema.union_schema([schema])
|
||||
|
||||
# Reverse the choices list before extending the stack so that they get handled in the order they occur
|
||||
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
|
||||
self._choices_to_handle.extend(choices_schemas)
|
||||
while self._choices_to_handle:
|
||||
choice = self._choices_to_handle.pop()
|
||||
self._handle_choice(choice)
|
||||
|
||||
if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
|
||||
# * We need to annotate `discriminator` as a union here to handle both branches of this conditional
|
||||
# * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
|
||||
# invariance of list, and because list[list[str | int]] is the type of the discriminator argument
|
||||
# to tagged_union_schema below
|
||||
# * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
|
||||
# interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
|
||||
# is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
|
||||
discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
|
||||
else:
|
||||
discriminator = self.discriminator
|
||||
return core_schema.tagged_union_schema(
|
||||
choices=self._tagged_union_choices,
|
||||
discriminator=discriminator,
|
||||
custom_error_type=schema.get('custom_error_type'),
|
||||
custom_error_message=schema.get('custom_error_message'),
|
||||
custom_error_context=schema.get('custom_error_context'),
|
||||
strict=False,
|
||||
from_attributes=True,
|
||||
ref=schema.get('ref'),
|
||||
metadata=schema.get('metadata'),
|
||||
serialization=schema.get('serialization'),
|
||||
)
|
||||
|
||||
def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
|
||||
"""This method handles the "middle" stage of recursion over the input schema.
|
||||
Specifically, it is responsible for handling each choice of the outermost union
|
||||
(and any "coalesced" choices obtained from inner unions).
|
||||
|
||||
Here, "handling" entails:
|
||||
* Coalescing nested unions and compatible tagged-unions
|
||||
* Tracking the presence of 'none' and 'nullable' schemas occurring as choices
|
||||
* Validating that each allowed discriminator value maps to a unique choice
|
||||
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
|
||||
"""
|
||||
if choice['type'] == 'definition-ref':
|
||||
if choice['schema_ref'] not in self.definitions:
|
||||
raise MissingDefinitionForUnionRef(choice['schema_ref'])
|
||||
|
||||
if choice['type'] == 'none':
|
||||
self._should_be_nullable = True
|
||||
elif choice['type'] == 'definitions':
|
||||
self._handle_choice(choice['schema'])
|
||||
elif choice['type'] == 'nullable':
|
||||
self._should_be_nullable = True
|
||||
self._handle_choice(choice['schema']) # unwrap the nullable schema
|
||||
elif choice['type'] == 'union':
|
||||
# Reverse the choices list before extending the stack so that they get handled in the order they occur
|
||||
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
|
||||
self._choices_to_handle.extend(choices_schemas)
|
||||
elif choice['type'] not in {
|
||||
'model',
|
||||
'typed-dict',
|
||||
'tagged-union',
|
||||
'lax-or-strict',
|
||||
'dataclass',
|
||||
'dataclass-args',
|
||||
'definition-ref',
|
||||
} and not _core_utils.is_function_with_inner_schema(choice):
|
||||
# We should eventually handle 'definition-ref' as well
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
else:
|
||||
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
|
||||
# In this case, this inner tagged-union is compatible with the outer tagged-union,
|
||||
# and its choices can be coalesced into the outer TaggedUnionSchema.
|
||||
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
|
||||
# Reverse the choices list before extending the stack so that they get handled in the order they occur
|
||||
self._choices_to_handle.extend(subchoices[::-1])
|
||||
return
|
||||
|
||||
inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
|
||||
self._set_unique_choice_for_values(choice, inferred_discriminator_values)
|
||||
|
||||
def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
|
||||
"""This method returns a boolean indicating whether the discriminator for the `choice`
|
||||
is the same as that being used for the outermost tagged union. This is used to
|
||||
determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
|
||||
or whether it should be treated as a separate (nested) choice.
|
||||
"""
|
||||
inner_discriminator = choice['discriminator']
|
||||
return inner_discriminator == self.discriminator or (
|
||||
isinstance(inner_discriminator, list)
|
||||
and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
|
||||
)
|
||||
|
||||
def _infer_discriminator_values_for_choice( # noqa C901
|
||||
self, choice: core_schema.CoreSchema, source_name: str | None
|
||||
) -> list[str | int]:
|
||||
"""This function recurses over `choice`, extracting all discriminator values that should map to this choice.
|
||||
|
||||
`model_name` is accepted for the purpose of producing useful error messages.
|
||||
"""
|
||||
if choice['type'] == 'definitions':
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
|
||||
elif choice['type'] == 'function-plain':
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
elif _core_utils.is_function_with_inner_schema(choice):
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
|
||||
elif choice['type'] == 'lax-or-strict':
|
||||
return sorted(
|
||||
set(
|
||||
self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
|
||||
+ self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
|
||||
)
|
||||
)
|
||||
|
||||
elif choice['type'] == 'tagged-union':
|
||||
values: list[str | int] = []
|
||||
# Ignore str/int "choices" since these are just references to other choices
|
||||
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
|
||||
for subchoice in subchoices:
|
||||
subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
|
||||
values.extend(subchoice_values)
|
||||
return values
|
||||
|
||||
elif choice['type'] == 'union':
|
||||
values = []
|
||||
for subchoice in choice['choices']:
|
||||
subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
|
||||
subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
|
||||
values.extend(subchoice_values)
|
||||
return values
|
||||
|
||||
elif choice['type'] == 'nullable':
|
||||
self._should_be_nullable = True
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)
|
||||
|
||||
elif choice['type'] == 'model':
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
|
||||
|
||||
elif choice['type'] == 'dataclass':
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
|
||||
|
||||
elif choice['type'] == 'model-fields':
|
||||
return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)
|
||||
|
||||
elif choice['type'] == 'dataclass-args':
|
||||
return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)
|
||||
|
||||
elif choice['type'] == 'typed-dict':
|
||||
return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)
|
||||
|
||||
elif choice['type'] == 'definition-ref':
|
||||
schema_ref = choice['schema_ref']
|
||||
if schema_ref not in self.definitions:
|
||||
raise MissingDefinitionForUnionRef(schema_ref)
|
||||
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
|
||||
else:
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
|
||||
def _infer_discriminator_values_for_typed_dict_choice(
|
||||
self, choice: core_schema.TypedDictSchema, source_name: str | None = None
|
||||
) -> list[str | int]:
|
||||
"""This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
|
||||
for the sake of readability.
|
||||
"""
|
||||
source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
|
||||
field = choice['fields'].get(self.discriminator)
|
||||
if field is None:
|
||||
raise PydanticUserError(
|
||||
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
|
||||
)
|
||||
return self._infer_discriminator_values_for_field(field, source)
|
||||
|
||||
def _infer_discriminator_values_for_model_choice(
|
||||
self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
|
||||
) -> list[str | int]:
|
||||
source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
|
||||
field = choice['fields'].get(self.discriminator)
|
||||
if field is None:
|
||||
raise PydanticUserError(
|
||||
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
|
||||
)
|
||||
return self._infer_discriminator_values_for_field(field, source)
|
||||
|
||||
def _infer_discriminator_values_for_dataclass_choice(
|
||||
self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
|
||||
) -> list[str | int]:
|
||||
source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
|
||||
for field in choice['fields']:
|
||||
if field['name'] == self.discriminator:
|
||||
break
|
||||
else:
|
||||
raise PydanticUserError(
|
||||
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
|
||||
)
|
||||
return self._infer_discriminator_values_for_field(field, source)
|
||||
|
||||
def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
|
||||
if field['type'] == 'computed-field':
|
||||
# This should never occur as a discriminator, as it is only relevant to serialization
|
||||
return []
|
||||
alias = field.get('validation_alias', self.discriminator)
|
||||
if not isinstance(alias, str):
|
||||
raise PydanticUserError(
|
||||
f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
|
||||
)
|
||||
if self._discriminator_alias is None:
|
||||
self._discriminator_alias = alias
|
||||
elif self._discriminator_alias != alias:
|
||||
raise PydanticUserError(
|
||||
f'Aliases for discriminator {self.discriminator!r} must be the same '
|
||||
f'(got {alias}, {self._discriminator_alias})',
|
||||
code='discriminator-alias',
|
||||
)
|
||||
return self._infer_discriminator_values_for_inner_schema(field['schema'], source)
|
||||
|
||||
def _infer_discriminator_values_for_inner_schema(
|
||||
self, schema: core_schema.CoreSchema, source: str
|
||||
) -> list[str | int]:
|
||||
"""When inferring discriminator values for a field, we typically extract the expected values from a literal
|
||||
schema. This function does that, but also handles nested unions and defaults.
|
||||
"""
|
||||
if schema['type'] == 'literal':
|
||||
return schema['expected']
|
||||
|
||||
elif schema['type'] == 'union':
|
||||
# Generally when multiple values are allowed they should be placed in a single `Literal`, but
|
||||
# we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
|
||||
# For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
|
||||
values: list[Any] = []
|
||||
for choice in schema['choices']:
|
||||
choice_schema = choice[0] if isinstance(choice, tuple) else choice
|
||||
choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
|
||||
values.extend(choice_values)
|
||||
return values
|
||||
|
||||
elif schema['type'] == 'default':
|
||||
# This will happen if the field has a default value; we ignore it while extracting the discriminator values
|
||||
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
|
||||
|
||||
elif schema['type'] == 'function-after':
|
||||
# After validators don't affect the discriminator values
|
||||
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
|
||||
|
||||
elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
|
||||
validator_type = repr(schema['type'].split('-')[1])
|
||||
raise PydanticUserError(
|
||||
f'Cannot use a mode={validator_type} validator in the'
|
||||
f' discriminator field {self.discriminator!r} of {source}',
|
||||
code='discriminator-validator',
|
||||
)
|
||||
|
||||
else:
|
||||
raise PydanticUserError(
|
||||
f'{source} needs field {self.discriminator!r} to be of type `Literal`',
|
||||
code='discriminator-needs-literal',
|
||||
)
|
||||
|
||||
def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
|
||||
"""This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
|
||||
provided `choice`, validating that none of these values already map to another (different) choice.
|
||||
"""
|
||||
for discriminator_value in values:
|
||||
if discriminator_value in self._tagged_union_choices:
|
||||
# It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
|
||||
# Because tagged_union_choices may map values to other values, we need to walk the choices dict
|
||||
# until we get to a "real" choice, and confirm that is equal to the one assigned.
|
||||
existing_choice = self._tagged_union_choices[discriminator_value]
|
||||
if existing_choice != choice:
|
||||
raise TypeError(
|
||||
f'Value {discriminator_value!r} for discriminator '
|
||||
f'{self.discriminator!r} mapped to multiple choices'
|
||||
)
|
||||
else:
|
||||
self._tagged_union_choices[discriminator_value] = choice
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Utilities related to attribute docstring extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DocstringVisitor(ast.NodeVisitor):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.target: str | None = None
|
||||
self.attrs: dict[str, str] = {}
|
||||
self.previous_node_type: type[ast.AST] | None = None
|
||||
|
||||
def visit(self, node: ast.AST) -> Any:
|
||||
node_result = super().visit(node)
|
||||
self.previous_node_type = type(node)
|
||||
return node_result
|
||||
|
||||
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.target = node.target.id
|
||||
|
||||
def visit_Expr(self, node: ast.Expr) -> Any:
|
||||
if (
|
||||
isinstance(node.value, ast.Constant)
|
||||
and isinstance(node.value.value, str)
|
||||
and self.previous_node_type is ast.AnnAssign
|
||||
):
|
||||
docstring = inspect.cleandoc(node.value.value)
|
||||
if self.target:
|
||||
self.attrs[self.target] = docstring
|
||||
self.target = None
|
||||
|
||||
|
||||
def _dedent_source_lines(source: list[str]) -> str:
|
||||
# Required for nested class definitions, e.g. in a function block
|
||||
dedent_source = textwrap.dedent(''.join(source))
|
||||
if dedent_source.startswith((' ', '\t')):
|
||||
# We are in the case where there's a dedented (usually multiline) string
|
||||
# at a lower indentation level than the class itself. We wrap our class
|
||||
# in a function as a workaround.
|
||||
dedent_source = f'def dedent_workaround():\n{dedent_source}'
|
||||
return dedent_source
|
||||
|
||||
|
||||
def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
|
||||
frame = inspect.currentframe()
|
||||
|
||||
while frame:
|
||||
if inspect.getmodule(frame) is inspect.getmodule(cls):
|
||||
lnum = frame.f_lineno
|
||||
try:
|
||||
lines, _ = inspect.findsource(frame)
|
||||
except OSError:
|
||||
# Source can't be retrieved (maybe because running in an interactive terminal),
|
||||
# we don't want to error here.
|
||||
pass
|
||||
else:
|
||||
block_lines = inspect.getblock(lines[lnum - 1 :])
|
||||
dedent_source = _dedent_source_lines(block_lines)
|
||||
try:
|
||||
block_tree = ast.parse(dedent_source)
|
||||
except SyntaxError:
|
||||
pass
|
||||
else:
|
||||
stmt = block_tree.body[0]
|
||||
if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
|
||||
# `_dedent_source_lines` wrapped the class around the workaround function
|
||||
stmt = stmt.body[0]
|
||||
if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
|
||||
return block_lines
|
||||
|
||||
frame = frame.f_back
|
||||
|
||||
|
||||
def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
|
||||
"""Map model attributes and their corresponding docstring.
|
||||
|
||||
Args:
|
||||
cls: The class of the Pydantic model to inspect.
|
||||
use_inspect: Whether to skip usage of frames to find the object and use
|
||||
the `inspect` module instead.
|
||||
|
||||
Returns:
|
||||
A mapping containing attribute names and their corresponding docstring.
|
||||
"""
|
||||
if use_inspect:
|
||||
# Might not work as expected if two classes have the same name in the same source file.
|
||||
try:
|
||||
source, _ = inspect.getsourcelines(cls)
|
||||
except OSError:
|
||||
return {}
|
||||
else:
|
||||
source = _extract_source_from_frame(cls)
|
||||
|
||||
if not source:
|
||||
return {}
|
||||
|
||||
dedent_source = _dedent_source_lines(source)
|
||||
|
||||
visitor = DocstringVisitor()
|
||||
visitor.visit(ast.parse(dedent_source))
|
||||
return visitor.attrs
|
||||
333
venv/lib/python3.11/site-packages/pydantic/_internal/_fields.py
Normal file
333
venv/lib/python3.11/site-packages/pydantic/_internal/_fields.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import warnings
|
||||
from copy import copy
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from pydantic.errors import PydanticUserError
|
||||
|
||||
from . import _typing_extra
|
||||
from ._config import ConfigWrapper
|
||||
from ._docs_extraction import extract_docstrings_from_cls
|
||||
from ._import_utils import import_cached_base_model, import_cached_field_info
|
||||
from ._repr import Representation
|
||||
from ._typing_extra import get_cls_type_hints_lenient, is_classvar, is_finalvar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from annotated_types import BaseMetadata
|
||||
|
||||
from ..fields import FieldInfo
|
||||
from ..main import BaseModel
|
||||
from ._dataclasses import StandardDataclass
|
||||
from ._decorators import DecoratorInfos
|
||||
|
||||
|
||||
class PydanticMetadata(Representation):
|
||||
"""Base class for annotation markers like `Strict`."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
|
||||
"""Create a new `_PydanticGeneralMetadata` class with the given metadata.
|
||||
|
||||
Args:
|
||||
**metadata: The metadata to add.
|
||||
|
||||
Returns:
|
||||
The new `_PydanticGeneralMetadata` class.
|
||||
"""
|
||||
return _general_metadata_cls()(metadata) # type: ignore
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _general_metadata_cls() -> type[BaseMetadata]:
|
||||
"""Do it this way to avoid importing `annotated_types` at import time."""
|
||||
from annotated_types import BaseMetadata
|
||||
|
||||
class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
|
||||
"""Pydantic general metadata like `max_digits`."""
|
||||
|
||||
def __init__(self, metadata: Any):
|
||||
self.__dict__ = metadata
|
||||
|
||||
return _PydanticGeneralMetadata # type: ignore
|
||||
|
||||
|
||||
def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper) -> None:
|
||||
if config_wrapper.use_attribute_docstrings:
|
||||
fields_docs = extract_docstrings_from_cls(cls)
|
||||
for ann_name, field_info in fields.items():
|
||||
if field_info.description is None and ann_name in fields_docs:
|
||||
field_info.description = fields_docs[ann_name]
|
||||
|
||||
|
||||
def collect_model_fields( # noqa: C901
|
||||
cls: type[BaseModel],
|
||||
bases: tuple[type[Any], ...],
|
||||
config_wrapper: ConfigWrapper,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
*,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
) -> tuple[dict[str, FieldInfo], set[str]]:
|
||||
"""Collect the fields of a nascent pydantic model.
|
||||
|
||||
Also collect the names of any ClassVars present in the type hints.
|
||||
|
||||
The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
|
||||
Returns:
|
||||
A tuple contains fields and class variables.
|
||||
|
||||
Raises:
|
||||
NameError:
|
||||
- If there is a conflict between a field name and protected namespaces.
|
||||
- If there is a field other than `root` in `RootModel`.
|
||||
- If a field shadows an attribute in the parent model.
|
||||
"""
|
||||
BaseModel = import_cached_base_model()
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
type_hints = get_cls_type_hints_lenient(cls, types_namespace)
|
||||
|
||||
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
|
||||
# annotations is only used for finding fields in parent classes
|
||||
annotations = cls.__dict__.get('__annotations__', {})
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
|
||||
class_vars: set[str] = set()
|
||||
for ann_name, ann_type in type_hints.items():
|
||||
if ann_name == 'model_config':
|
||||
# We never want to treat `model_config` as a field
|
||||
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
|
||||
# protected namespaces (where `model_config` might be allowed as a field name)
|
||||
continue
|
||||
for protected_namespace in config_wrapper.protected_namespaces:
|
||||
if ann_name.startswith(protected_namespace):
|
||||
for b in bases:
|
||||
if hasattr(b, ann_name):
|
||||
if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
|
||||
raise NameError(
|
||||
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
|
||||
f' of protected namespace "{protected_namespace}".'
|
||||
)
|
||||
else:
|
||||
valid_namespaces = tuple(
|
||||
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
|
||||
)
|
||||
warnings.warn(
|
||||
f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".'
|
||||
'\n\nYou may be able to resolve this warning by setting'
|
||||
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
|
||||
UserWarning,
|
||||
)
|
||||
if is_classvar(ann_type):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
if not is_valid_field_name(ann_name):
|
||||
continue
|
||||
if cls.__pydantic_root_model__ and ann_name != 'root':
|
||||
raise NameError(
|
||||
f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`"
|
||||
)
|
||||
|
||||
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
|
||||
# "... shadows an attribute" warnings
|
||||
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
|
||||
for base in bases:
|
||||
dataclass_fields = {
|
||||
field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ())
|
||||
}
|
||||
if hasattr(base, ann_name):
|
||||
if base is generic_origin:
|
||||
# Don't warn when "shadowing" of attributes in parametrized generics
|
||||
continue
|
||||
|
||||
if ann_name in dataclass_fields:
|
||||
# Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
|
||||
# on the class instance.
|
||||
continue
|
||||
|
||||
if ann_name not in annotations:
|
||||
# Don't warn when a field exists in a parent class but has not been defined in the current class
|
||||
continue
|
||||
|
||||
warnings.warn(
|
||||
f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
|
||||
f'"{base.__qualname__}"',
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
try:
|
||||
default = getattr(cls, ann_name, PydanticUndefined)
|
||||
if default is PydanticUndefined:
|
||||
raise AttributeError
|
||||
except AttributeError:
|
||||
if ann_name in annotations:
|
||||
field_info = FieldInfo_.from_annotation(ann_type)
|
||||
else:
|
||||
# if field has no default value and is not in __annotations__ this means that it is
|
||||
# defined in a base class and we can take it from there
|
||||
model_fields_lookup: dict[str, FieldInfo] = {}
|
||||
for x in cls.__bases__[::-1]:
|
||||
model_fields_lookup.update(getattr(x, 'model_fields', {}))
|
||||
if ann_name in model_fields_lookup:
|
||||
# The field was present on one of the (possibly multiple) base classes
|
||||
# copy the field to make sure typevar substitutions don't cause issues with the base classes
|
||||
field_info = copy(model_fields_lookup[ann_name])
|
||||
else:
|
||||
# The field was not found on any base classes; this seems to be caused by fields not getting
|
||||
# generated thanks to models not being fully defined while initializing recursive models.
|
||||
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
|
||||
field_info = FieldInfo_.from_annotation(ann_type)
|
||||
else:
|
||||
_warn_on_nested_alias_in_annotation(ann_type, ann_name)
|
||||
field_info = FieldInfo_.from_annotated_attribute(ann_type, default)
|
||||
# attributes which are fields are removed from the class namespace:
|
||||
# 1. To match the behaviour of annotation-only fields
|
||||
# 2. To avoid false positives in the NameError check above
|
||||
try:
|
||||
delattr(cls, ann_name)
|
||||
except AttributeError:
|
||||
pass # indicates the attribute was on a parent class
|
||||
|
||||
# Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__
|
||||
# to make sure the decorators have already been built for this exact class
|
||||
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
|
||||
if ann_name in decorators.computed_fields:
|
||||
raise ValueError("you can't override a field with a computed field")
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
|
||||
_update_fields_from_docstrings(cls, fields, config_wrapper)
|
||||
|
||||
return fields, class_vars
|
||||
|
||||
|
||||
def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None:
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
args = getattr(ann_type, '__args__', None)
|
||||
if args:
|
||||
for anno_arg in args:
|
||||
if _typing_extra.is_annotated(anno_arg):
|
||||
for anno_type_arg in _typing_extra.get_args(anno_arg):
|
||||
if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
|
||||
warnings.warn(
|
||||
f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
|
||||
UserWarning,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
if not is_finalvar(type_):
|
||||
return False
|
||||
elif val is PydanticUndefined:
|
||||
return False
|
||||
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def collect_dataclass_fields(
|
||||
cls: type[StandardDataclass],
|
||||
types_namespace: dict[str, Any] | None,
|
||||
*,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
config_wrapper: ConfigWrapper | None = None,
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Collect the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
cls: dataclass.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
config_wrapper: The config wrapper instance.
|
||||
|
||||
Returns:
|
||||
The dataclass fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
|
||||
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
|
||||
|
||||
source_module = sys.modules.get(cls.__module__)
|
||||
if source_module is not None:
|
||||
types_namespace = {**source_module.__dict__, **(types_namespace or {})}
|
||||
|
||||
for ann_name, dataclass_field in dataclass_fields.items():
|
||||
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
|
||||
if is_classvar(ann_type):
|
||||
continue
|
||||
|
||||
if (
|
||||
not dataclass_field.init
|
||||
and dataclass_field.default == dataclasses.MISSING
|
||||
and dataclass_field.default_factory == dataclasses.MISSING
|
||||
):
|
||||
# TODO: We should probably do something with this so that validate_assignment behaves properly
|
||||
# Issue: https://github.com/pydantic/pydantic/issues/5470
|
||||
continue
|
||||
|
||||
if isinstance(dataclass_field.default, FieldInfo_):
|
||||
if dataclass_field.default.init_var:
|
||||
if dataclass_field.default.init is False:
|
||||
raise PydanticUserError(
|
||||
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
|
||||
code='clashing-init-and-init-var',
|
||||
)
|
||||
|
||||
# TODO: same note as above re validate_assignment
|
||||
continue
|
||||
field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field.default)
|
||||
else:
|
||||
field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field)
|
||||
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo_):
|
||||
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
|
||||
setattr(cls, ann_name, field_info.default)
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
|
||||
if config_wrapper is not None:
|
||||
_update_fields_from_docstrings(cls, fields, config_wrapper)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def is_valid_field_name(name: str) -> bool:
|
||||
return not name.startswith('_')
|
||||
|
||||
|
||||
def is_valid_privateattr_name(name: str) -> bool:
|
||||
return name.startswith('_') and not name.startswith('__')
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class PydanticRecursiveRef:
|
||||
type_ref: str
|
||||
|
||||
__name__ = 'PydanticRecursiveRef'
|
||||
__hash__ = object.__hash__
|
||||
|
||||
def __call__(self) -> None:
|
||||
"""Defining __call__ is necessary for the `typing` module to let you use an instance of
|
||||
this class as the result of resolving a standard ForwardRef.
|
||||
"""
|
||||
|
||||
def __or__(self, other):
|
||||
return Union[self, other] # type: ignore
|
||||
|
||||
def __ror__(self, other):
|
||||
return Union[other, self] # type: ignore
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,518 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections import ChainMap
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from types import prepare_class
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import typing_extensions
|
||||
|
||||
from ._core_utils import get_type_ref
|
||||
from ._forward_ref import PydanticRecursiveRef
|
||||
from ._typing_extra import TypeVarType, typing_base
|
||||
from ._utils import all_identical, is_model_class
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import _UnionGenericAlias # type: ignore[attr-defined]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..main import BaseModel
|
||||
|
||||
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
|
||||
|
||||
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
|
||||
# Right now, to handle recursive generics, we some types must remain cached for brief periods without references.
|
||||
# By chaining the WeakValuesDict with a LimitedDict, we have a way to retain caching for all types with references,
|
||||
# while also retaining a limited number of types even without references. This is generally enough to build
|
||||
# specific recursive generic models without losing required items out of the cache.
|
||||
|
||||
KT = TypeVar('KT')
|
||||
VT = TypeVar('VT')
|
||||
_LIMITED_DICT_SIZE = 100
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class LimitedDict(dict, MutableMapping[KT, VT]):
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE): ...
|
||||
|
||||
else:
|
||||
|
||||
class LimitedDict(dict):
|
||||
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
|
||||
|
||||
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
|
||||
"""
|
||||
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
|
||||
def __setitem__(self, key: Any, value: Any, /) -> None:
|
||||
super().__setitem__(key, value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for k in to_remove:
|
||||
del self[k]
|
||||
|
||||
|
||||
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
|
||||
# once they are no longer referenced by the caller.
|
||||
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
|
||||
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
|
||||
else:
|
||||
GenericTypesCache = WeakValueDictionary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class DeepChainMap(ChainMap[KT, VT]): # type: ignore
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
class DeepChainMap(ChainMap):
|
||||
"""Variant of ChainMap that allows direct updates to inner scopes.
|
||||
|
||||
Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
|
||||
with some light modifications for this use case.
|
||||
"""
|
||||
|
||||
def clear(self) -> None:
|
||||
for mapping in self.maps:
|
||||
mapping.clear()
|
||||
|
||||
def __setitem__(self, key: KT, value: VT) -> None:
|
||||
for mapping in self.maps:
|
||||
mapping[key] = value
|
||||
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
hit = False
|
||||
for mapping in self.maps:
|
||||
if key in mapping:
|
||||
del mapping[key]
|
||||
hit = True
|
||||
if not hit:
|
||||
raise KeyError(key)
|
||||
|
||||
|
||||
# Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it
|
||||
# and discover later on that we need to re-add all this infrastructure...
|
||||
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
|
||||
|
||||
_GENERIC_TYPES_CACHE = GenericTypesCache()
|
||||
|
||||
|
||||
class PydanticGenericMetadata(typing_extensions.TypedDict):
|
||||
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
|
||||
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
|
||||
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
|
||||
|
||||
|
||||
def create_generic_submodel(
|
||||
model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
|
||||
) -> type[BaseModel]:
|
||||
"""Dynamically create a submodel of a provided (generic) BaseModel.
|
||||
|
||||
This is used when producing concrete parametrizations of generic models. This function
|
||||
only *creates* the new subclass; the schema/validators/serialization must be updated to
|
||||
reflect a concrete parametrization elsewhere.
|
||||
|
||||
Args:
|
||||
model_name: The name of the newly created model.
|
||||
origin: The base class for the new model to inherit from.
|
||||
args: A tuple of generic metadata arguments.
|
||||
params: A tuple of generic metadata parameters.
|
||||
|
||||
Returns:
|
||||
The created submodel.
|
||||
"""
|
||||
namespace: dict[str, Any] = {'__module__': origin.__module__}
|
||||
bases = (origin,)
|
||||
meta, ns, kwds = prepare_class(model_name, bases)
|
||||
namespace.update(ns)
|
||||
created_model = meta(
|
||||
model_name,
|
||||
bases,
|
||||
namespace,
|
||||
__pydantic_generic_metadata__={
|
||||
'origin': origin,
|
||||
'args': args,
|
||||
'parameters': params,
|
||||
},
|
||||
__pydantic_reset_parent_namespace__=False,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
model_module, called_globally = _get_caller_frame_info(depth=3)
|
||||
if called_globally: # create global reference and therefore allow pickling
|
||||
object_by_reference = None
|
||||
reference_name = model_name
|
||||
reference_module_globals = sys.modules[created_model.__module__].__dict__
|
||||
while object_by_reference is not created_model:
|
||||
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
|
||||
reference_name += '_'
|
||||
|
||||
return created_model
|
||||
|
||||
|
||||
def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
|
||||
"""Used inside a function to check whether it was called globally.
|
||||
|
||||
Args:
|
||||
depth: The depth to get the frame.
|
||||
|
||||
Returns:
|
||||
A tuple contains `module_name` and `called_globally`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the function is not called inside a function.
|
||||
"""
|
||||
try:
|
||||
previous_caller_frame = sys._getframe(depth)
|
||||
except ValueError as e:
|
||||
raise RuntimeError('This function must be used inside another function') from e
|
||||
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
|
||||
return None, False
|
||||
frame_globals = previous_caller_frame.f_globals
|
||||
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
|
||||
|
||||
|
||||
DictValues: type[Any] = {}.values().__class__
|
||||
|
||||
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
|
||||
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
|
||||
|
||||
This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
|
||||
since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
|
||||
"""
|
||||
if isinstance(v, TypeVar):
|
||||
yield v
|
||||
elif is_model_class(v):
|
||||
yield from v.__pydantic_generic_metadata__['parameters']
|
||||
elif isinstance(v, (DictValues, list)):
|
||||
for var in v:
|
||||
yield from iter_contained_typevars(var)
|
||||
else:
|
||||
args = get_args(v)
|
||||
for arg in args:
|
||||
yield from iter_contained_typevars(arg)
|
||||
|
||||
|
||||
def get_args(v: Any) -> Any:
|
||||
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
|
||||
if pydantic_generic_metadata:
|
||||
return pydantic_generic_metadata.get('args')
|
||||
return typing_extensions.get_args(v)
|
||||
|
||||
|
||||
def get_origin(v: Any) -> Any:
|
||||
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
|
||||
if pydantic_generic_metadata:
|
||||
return pydantic_generic_metadata.get('origin')
|
||||
return typing_extensions.get_origin(v)
|
||||
|
||||
|
||||
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
|
||||
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
|
||||
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
|
||||
"""
|
||||
origin = get_origin(cls)
|
||||
if origin is None:
|
||||
return None
|
||||
if not hasattr(origin, '__parameters__'):
|
||||
return None
|
||||
|
||||
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
|
||||
# So it is safe to access cls.__args__ and origin.__parameters__
|
||||
args: tuple[Any, ...] = cls.__args__ # type: ignore
|
||||
parameters: tuple[TypeVarType, ...] = origin.__parameters__
|
||||
return dict(zip(parameters, args))
|
||||
|
||||
|
||||
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
|
||||
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
|
||||
with the `replace_types` function.
|
||||
|
||||
Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
|
||||
stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
|
||||
"""
|
||||
# TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
|
||||
# in the __origin__, __args__, and __parameters__ attributes of the model.
|
||||
generic_metadata = cls.__pydantic_generic_metadata__
|
||||
origin = generic_metadata['origin']
|
||||
args = generic_metadata['args']
|
||||
return dict(zip(iter_contained_typevars(origin), args))
|
||||
|
||||
|
||||
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
|
||||
|
||||
Args:
|
||||
type_: The class or generic alias.
|
||||
type_map: Mapping from `TypeVar` instance to concrete types.
|
||||
|
||||
Returns:
|
||||
A new type representing the basic structure of `type_` with all
|
||||
`typevar_map` keys recursively replaced.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from pydantic._internal._generics import replace_types
|
||||
|
||||
replace_types(Tuple[str, Union[List[str], float]], {str: int})
|
||||
#> Tuple[int, Union[List[int], float]]
|
||||
```
|
||||
"""
|
||||
if not type_map:
|
||||
return type_
|
||||
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
annotated = replace_types(annotated_type, type_map)
|
||||
for annotation in annotations:
|
||||
annotated = typing_extensions.Annotated[annotated, annotation]
|
||||
return annotated
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if type_args:
|
||||
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
# If all arguments are the same, there is no need to modify the
|
||||
# type or create a new object at all
|
||||
return type_
|
||||
if (
|
||||
origin_type is not None
|
||||
and isinstance(type_, typing_base)
|
||||
and not isinstance(origin_type, typing_base)
|
||||
and getattr(type_, '_name', None) is not None
|
||||
):
|
||||
# In python < 3.9 generic aliases don't exist so any of these like `list`,
|
||||
# `type` or `collections.abc.Callable` need to be translated.
|
||||
# See: https://www.python.org/dev/peps/pep-0585
|
||||
origin_type = getattr(typing, type_._name)
|
||||
assert origin_type is not None
|
||||
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
|
||||
# We also cannot use isinstance() since we have to compare types.
|
||||
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
|
||||
return _UnionGenericAlias(origin_type, resolved_type_args)
|
||||
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
|
||||
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
|
||||
|
||||
# We handle pydantic generic models separately as they don't have the same
|
||||
# semantics as "typing" classes or generic aliases
|
||||
|
||||
if not origin_type and is_model_class(type_):
|
||||
parameters = type_.__pydantic_generic_metadata__['parameters']
|
||||
if not parameters:
|
||||
return type_
|
||||
resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
|
||||
if all_identical(parameters, resolved_type_args):
|
||||
return type_
|
||||
return type_[resolved_type_args]
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)):
|
||||
resolved_list = [replace_types(element, type_map) for element in type_]
|
||||
if all_identical(type_, resolved_list):
|
||||
return type_
|
||||
return resolved_list
|
||||
|
||||
# If all else fails, we try to resolve the type directly and otherwise just
|
||||
# return the input with no modifications.
|
||||
return type_map.get(type_, type_)
|
||||
|
||||
|
||||
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
|
||||
"""Checks if the type, or any of its arbitrary nested args, satisfy
|
||||
`isinstance(<type>, isinstance_target)`.
|
||||
"""
|
||||
if isinstance(type_, isinstance_target):
|
||||
return True
|
||||
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
return has_instance_in_type(annotated_type, isinstance_target)
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
for arg in type_args:
|
||||
if has_instance_in_type(arg, isinstance_target):
|
||||
return True
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
|
||||
for element in type_:
|
||||
if has_instance_in_type(element, isinstance_target):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
|
||||
"""Check the generic model parameters count is equal.
|
||||
|
||||
Args:
|
||||
cls: The generic model.
|
||||
parameters: A tuple of passed parameters to the generic model.
|
||||
|
||||
Raises:
|
||||
TypeError: If the passed parameters count is not equal to generic model parameters count.
|
||||
"""
|
||||
actual = len(parameters)
|
||||
expected = len(cls.__pydantic_generic_metadata__['parameters'])
|
||||
if actual != expected:
|
||||
description = 'many' if actual > expected else 'few'
|
||||
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
|
||||
|
||||
|
||||
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def generic_recursion_self_type(
|
||||
origin: type[BaseModel], args: tuple[Any, ...]
|
||||
) -> Iterator[PydanticRecursiveRef | None]:
|
||||
"""This contextmanager should be placed around the recursive calls used to build a generic type,
|
||||
and accept as arguments the generic origin type and the type arguments being passed to it.
|
||||
|
||||
If the same origin and arguments are observed twice, it implies that a self-reference placeholder
|
||||
can be used while building the core schema, and will produce a schema_ref that will be valid in the
|
||||
final parent schema.
|
||||
"""
|
||||
previously_seen_type_refs = _generic_recursion_cache.get()
|
||||
if previously_seen_type_refs is None:
|
||||
previously_seen_type_refs = set()
|
||||
token = _generic_recursion_cache.set(previously_seen_type_refs)
|
||||
else:
|
||||
token = None
|
||||
|
||||
try:
|
||||
type_ref = get_type_ref(origin, args_override=args)
|
||||
if type_ref in previously_seen_type_refs:
|
||||
self_type = PydanticRecursiveRef(type_ref=type_ref)
|
||||
yield self_type
|
||||
else:
|
||||
previously_seen_type_refs.add(type_ref)
|
||||
yield None
|
||||
finally:
|
||||
if token:
|
||||
_generic_recursion_cache.reset(token)
|
||||
|
||||
|
||||
def recursively_defined_type_refs() -> set[str]:
|
||||
visited = _generic_recursion_cache.get()
|
||||
if not visited:
|
||||
return set() # not in a generic recursion, so there are no types
|
||||
|
||||
return visited.copy() # don't allow modifications
|
||||
|
||||
|
||||
def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
|
||||
"""The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
|
||||
repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
|
||||
while still ensuring that certain alternative parametrizations ultimately resolve to the same type.
|
||||
|
||||
As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
|
||||
The approach could be modified to not use two different cache keys at different points, but the
|
||||
_early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
|
||||
_late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
|
||||
same after resolving the type arguments will always produce cache hits.
|
||||
|
||||
If we wanted to move to only using a single cache key per type, we would either need to always use the
|
||||
slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
|
||||
that Model[List[T]][int] is a different type than Model[List[T]][int]. Because we rely on subclass relationships
|
||||
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
|
||||
equal.
|
||||
"""
|
||||
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
|
||||
|
||||
|
||||
def get_cached_generic_type_late(
|
||||
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
|
||||
) -> type[BaseModel] | None:
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
|
||||
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
|
||||
if cached is not None:
|
||||
set_cached_generic_type(parent, typevar_values, cached, origin, args)
|
||||
return cached
|
||||
|
||||
|
||||
def set_cached_generic_type(
|
||||
parent: type[BaseModel],
|
||||
typevar_values: tuple[Any, ...],
|
||||
type_: type[BaseModel],
|
||||
origin: type[BaseModel] | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
) -> None:
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
|
||||
two different keys.
|
||||
"""
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
|
||||
if len(typevar_values) == 1:
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
|
||||
if origin and args:
|
||||
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
|
||||
|
||||
|
||||
def _union_orderings_key(typevar_values: Any) -> Any:
|
||||
"""This is intended to help differentiate between Union types with the same arguments in different order.
|
||||
|
||||
Thanks to caching internal to the `typing` module, it is not possible to distinguish between
|
||||
List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
|
||||
because `typing` considers Union[int, float] to be equal to Union[float, int].
|
||||
|
||||
However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
|
||||
Because we parse items as the first Union type that is successful, we get slightly more consistent behavior
|
||||
if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
|
||||
get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
|
||||
(See https://github.com/python/cpython/issues/86483 for reference.)
|
||||
"""
|
||||
if isinstance(typevar_values, tuple):
|
||||
args_data = []
|
||||
for value in typevar_values:
|
||||
args_data.append(_union_orderings_key(value))
|
||||
return tuple(args_data)
|
||||
elif typing_extensions.get_origin(typevar_values) is typing.Union:
|
||||
return get_args(typevar_values)
|
||||
else:
|
||||
return ()
|
||||
|
||||
|
||||
def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
|
||||
"""This is intended for minimal computational overhead during lookups of cached types.
|
||||
|
||||
Note that this is overly simplistic, and it's possible that two different cls/typevar_values
|
||||
inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
|
||||
To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
|
||||
lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
|
||||
would result in the same type.
|
||||
"""
|
||||
return cls, typevar_values, _union_orderings_key(typevar_values)
|
||||
|
||||
|
||||
def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
|
||||
"""This is intended for use later in the process of creating a new type, when we have more information
|
||||
about the exact args that will be passed. If it turns out that a different set of inputs to
|
||||
__class_getitem__ resulted in the same inputs to the generic type creation process, we can still
|
||||
return the cached type, and update the cache with the _early_cache_key as well.
|
||||
"""
|
||||
# The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
|
||||
# _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
|
||||
# whereas this function will always produce a tuple as the first item in the key.
|
||||
return _union_orderings_key(typevar_values), origin, args
|
||||
27
venv/lib/python3.11/site-packages/pydantic/_internal/_git.py
Normal file
27
venv/lib/python3.11/site-packages/pydantic/_internal/_git.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
def is_git_repo(dir: str) -> bool:
|
||||
"""Is the given directory version-controlled with git?"""
|
||||
return os.path.exists(os.path.join(dir, '.git'))
|
||||
|
||||
|
||||
def have_git() -> bool:
|
||||
"""Can we run the git executable?"""
|
||||
try:
|
||||
subprocess.check_output(['git', '--help'])
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def git_revision(dir: str) -> str:
|
||||
"""Get the SHA-1 of the HEAD of a git repository."""
|
||||
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
|
||||
@@ -0,0 +1,20 @@
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def import_cached_base_model() -> Type['BaseModel']:
|
||||
from pydantic import BaseModel
|
||||
|
||||
return BaseModel
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def import_cached_field_info() -> Type['FieldInfo']:
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
return FieldInfo
|
||||
@@ -0,0 +1,7 @@
|
||||
import sys
|
||||
|
||||
# `slots` is available on Python >= 3.10
|
||||
if sys.version_info >= (3, 10):
|
||||
slots_true = {'slots': True}
|
||||
else:
|
||||
slots_true = {}
|
||||
@@ -0,0 +1,397 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable
|
||||
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from ._fields import PydanticMetadata
|
||||
from ._import_utils import import_cached_field_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..annotated_handlers import GetJsonSchemaHandler
|
||||
|
||||
STRICT = {'strict'}
|
||||
FAIL_FAST = {'fail_fast'}
|
||||
LENGTH_CONSTRAINTS = {'min_length', 'max_length'}
|
||||
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
|
||||
NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY}
|
||||
ALLOW_INF_NAN = {'allow_inf_nan'}
|
||||
|
||||
STR_CONSTRAINTS = {
|
||||
*LENGTH_CONSTRAINTS,
|
||||
*STRICT,
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'pattern',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
|
||||
LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST}
|
||||
|
||||
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS}
|
||||
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
BOOL_CONSTRAINTS = STRICT
|
||||
UUID_CONSTRAINTS = STRICT
|
||||
|
||||
DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
LAX_OR_STRICT_CONSTRAINTS = STRICT
|
||||
ENUM_CONSTRAINTS = STRICT
|
||||
COMPLEX_CONSTRAINTS = STRICT
|
||||
|
||||
UNION_CONSTRAINTS = {'union_mode'}
|
||||
URL_CONSTRAINTS = {
|
||||
'max_length',
|
||||
'allowed_schemes',
|
||||
'host_required',
|
||||
'default_host',
|
||||
'default_port',
|
||||
'default_path',
|
||||
}
|
||||
|
||||
TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url')
|
||||
SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES)
|
||||
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
|
||||
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [
|
||||
(STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
|
||||
(BYTES_CONSTRAINTS, ('bytes',)),
|
||||
(LIST_CONSTRAINTS, ('list',)),
|
||||
(TUPLE_CONSTRAINTS, ('tuple',)),
|
||||
(SET_CONSTRAINTS, ('set', 'frozenset')),
|
||||
(DICT_CONSTRAINTS, ('dict',)),
|
||||
(GENERATOR_CONSTRAINTS, ('generator',)),
|
||||
(FLOAT_CONSTRAINTS, ('float',)),
|
||||
(INT_CONSTRAINTS, ('int',)),
|
||||
(DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')),
|
||||
# TODO: this is a bit redundant, we could probably avoid some of these
|
||||
(STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')),
|
||||
(UNION_CONSTRAINTS, ('union',)),
|
||||
(URL_CONSTRAINTS, ('url', 'multi-host-url')),
|
||||
(BOOL_CONSTRAINTS, ('bool',)),
|
||||
(UUID_CONSTRAINTS, ('uuid',)),
|
||||
(LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)),
|
||||
(ENUM_CONSTRAINTS, ('enum',)),
|
||||
(DECIMAL_CONSTRAINTS, ('decimal',)),
|
||||
(COMPLEX_CONSTRAINTS, ('complex',)),
|
||||
]
|
||||
|
||||
for constraints, schemas in constraint_schema_pairings:
|
||||
for c in constraints:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas)
|
||||
|
||||
|
||||
def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None:
|
||||
def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]:
|
||||
js_schema = handler(s)
|
||||
js_schema.update(f())
|
||||
return js_schema
|
||||
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if 'pydantic_js_functions' in s:
|
||||
metadata['pydantic_js_functions'].append(update_js_schema)
|
||||
else:
|
||||
metadata['pydantic_js_functions'] = [update_js_schema]
|
||||
else:
|
||||
s['metadata'] = {'pydantic_js_functions': [update_js_schema]}
|
||||
|
||||
|
||||
def as_jsonable_value(v: Any) -> Any:
|
||||
if type(v) not in (int, str, float, bytes, bool, type(None)):
|
||||
return to_jsonable_python(v)
|
||||
return v
|
||||
|
||||
|
||||
def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
"""Expand the annotations.
|
||||
|
||||
Args:
|
||||
annotations: An iterable of annotations.
|
||||
|
||||
Returns:
|
||||
An iterable of expanded annotations.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from annotated_types import Ge, Len
|
||||
|
||||
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
|
||||
|
||||
print(list(expand_grouped_metadata([Ge(4), Len(5)])))
|
||||
#> [Ge(ge=4), MinLen(min_length=5)]
|
||||
```
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, at.GroupedMetadata):
|
||||
yield from annotation
|
||||
elif isinstance(annotation, FieldInfo):
|
||||
yield from annotation.metadata
|
||||
# this is a bit problematic in that it results in duplicate metadata
|
||||
# all of our "consumers" can handle it, but it is not ideal
|
||||
# we probably should split up FieldInfo into:
|
||||
# - annotated types metadata
|
||||
# - individual metadata known only to Pydantic
|
||||
annotation = copy(annotation)
|
||||
annotation.metadata = []
|
||||
yield annotation
|
||||
else:
|
||||
yield annotation
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_at_to_constraint_map() -> dict[type, str]:
|
||||
"""Return a mapping of annotated types to constraints.
|
||||
|
||||
Normally, we would define a mapping like this in the module scope, but we can't do that
|
||||
because we don't permit module level imports of `annotated_types`, in an attempt to speed up
|
||||
the import time of `pydantic`. We still only want to have this dictionary defined in one place,
|
||||
so we use this function to cache the result.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
return {
|
||||
at.Gt: 'gt',
|
||||
at.Ge: 'ge',
|
||||
at.Lt: 'lt',
|
||||
at.Le: 'le',
|
||||
at.MultipleOf: 'multiple_of',
|
||||
at.MinLen: 'min_length',
|
||||
at.MaxLen: 'max_length',
|
||||
}
|
||||
|
||||
|
||||
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
|
||||
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
|
||||
Otherwise return `None`.
|
||||
|
||||
This does not handle all known annotations. If / when it does, it can always
|
||||
return a CoreSchema and return the unmodified schema if the annotation should be ignored.
|
||||
|
||||
Assumes that GroupedMetadata has already been expanded via `expand_grouped_metadata`.
|
||||
|
||||
Args:
|
||||
annotation: The annotation.
|
||||
schema: The schema.
|
||||
|
||||
Returns:
|
||||
An updated schema with annotation if it is an annotation we know about, `None` otherwise.
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: If `Predicate` fails.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
from ._validators import forbid_inf_nan_check, get_constraint_validator
|
||||
|
||||
schema = schema.copy()
|
||||
schema_update, other_metadata = collect_known_metadata([annotation])
|
||||
schema_type = schema['type']
|
||||
|
||||
chain_schema_constraints: set[str] = {
|
||||
'pattern',
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
chain_schema_steps: list[CoreSchema] = []
|
||||
|
||||
for constraint, value in schema_update.items():
|
||||
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
|
||||
raise ValueError(f'Unknown constraint {constraint}')
|
||||
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]
|
||||
|
||||
# if it becomes necessary to handle more than one constraint
|
||||
# in this recursive case with function-after or function-wrap, we should refactor
|
||||
# this is a bit challenging because we sometimes want to apply constraints to the inner schema,
|
||||
# whereas other times we want to wrap the existing schema with a new one that enforces a new constraint.
|
||||
if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
|
||||
schema['schema'] = apply_known_metadata(annotation, schema['schema']) # type: ignore # schema is function schema
|
||||
return schema
|
||||
|
||||
# if we're allowed to apply constraint directly to the schema, like le to int, do that
|
||||
if schema_type in allowed_schemas:
|
||||
if constraint == 'union_mode' and schema_type == 'union':
|
||||
schema['mode'] = value # type: ignore # schema is UnionSchema
|
||||
else:
|
||||
schema[constraint] = value
|
||||
continue
|
||||
|
||||
# else, apply a function after validator to the schema to enforce the corresponding constraint
|
||||
if constraint in chain_schema_constraints:
|
||||
|
||||
def _apply_constraint_with_incompatibility_info(
|
||||
value: Any, handler: cs.ValidatorFunctionWrapHandler
|
||||
) -> Any:
|
||||
try:
|
||||
x = handler(value)
|
||||
except ValidationError as ve:
|
||||
# if the error is about the type, it's likely that the constraint is incompatible the type of the field
|
||||
# for example, the following invalid schema wouldn't be caught during schema build, but rather at this point
|
||||
# with a cryptic 'string_type' error coming from the string validator,
|
||||
# that we'd rather express as a constraint incompatibility error (TypeError)
|
||||
# Annotated[list[int], Field(pattern='abc')]
|
||||
if 'type' in ve.errors()[0]['type']:
|
||||
raise TypeError(
|
||||
f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'" # noqa: B023
|
||||
)
|
||||
raise ve
|
||||
return x
|
||||
|
||||
chain_schema_steps.append(
|
||||
cs.no_info_wrap_validator_function(
|
||||
_apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value})
|
||||
)
|
||||
)
|
||||
elif constraint in {*NUMERIC_CONSTRAINTS, *LENGTH_CONSTRAINTS}:
|
||||
if constraint in NUMERIC_CONSTRAINTS:
|
||||
json_schema_constraint = constraint
|
||||
elif constraint in LENGTH_CONSTRAINTS:
|
||||
inner_schema = schema
|
||||
while inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}:
|
||||
inner_schema = inner_schema['schema'] # type: ignore
|
||||
inner_schema_type = inner_schema['type']
|
||||
if inner_schema_type == 'list' or (
|
||||
inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list' # type: ignore
|
||||
):
|
||||
json_schema_constraint = 'minItems' if constraint == 'min_length' else 'maxItems'
|
||||
else:
|
||||
json_schema_constraint = 'minLength' if constraint == 'min_length' else 'maxLength'
|
||||
|
||||
schema = cs.no_info_after_validator_function(
|
||||
partial(get_constraint_validator(constraint), **{'constraint_value': value}), schema
|
||||
)
|
||||
add_js_update_schema(schema, lambda: {json_schema_constraint: as_jsonable_value(value)}) # noqa: B023
|
||||
elif constraint == 'allow_inf_nan' and value is False:
|
||||
schema = cs.no_info_after_validator_function(
|
||||
forbid_inf_nan_check,
|
||||
schema,
|
||||
)
|
||||
else:
|
||||
# It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it
|
||||
# Most constraint errors are caught at runtime during attempted application
|
||||
raise RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'")
|
||||
|
||||
for annotation in other_metadata:
|
||||
if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
|
||||
constraint = at_to_constraint_map[annotation_type]
|
||||
schema = cs.no_info_after_validator_function(
|
||||
partial(get_constraint_validator(constraint), {constraint: getattr(annotation, constraint)}), schema
|
||||
)
|
||||
continue
|
||||
elif isinstance(annotation, (at.Predicate, at.Not)):
|
||||
predicate_name = f'{annotation.func.__qualname__}' if hasattr(annotation.func, '__qualname__') else ''
|
||||
|
||||
def val_func(v: Any) -> Any:
|
||||
predicate_satisfied = annotation.func(v) # noqa: B023
|
||||
|
||||
# annotation.func may also raise an exception, let it pass through
|
||||
if isinstance(annotation, at.Predicate): # noqa: B023
|
||||
if not predicate_satisfied:
|
||||
raise PydanticCustomError(
|
||||
'predicate_failed',
|
||||
f'Predicate {predicate_name} failed', # type: ignore # noqa: B023
|
||||
)
|
||||
else:
|
||||
if predicate_satisfied:
|
||||
raise PydanticCustomError(
|
||||
'not_operation_failed',
|
||||
f'Not of {predicate_name} failed', # type: ignore # noqa: B023
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
schema = cs.no_info_after_validator_function(val_func, schema)
|
||||
else:
|
||||
# ignore any other unknown metadata
|
||||
return None
|
||||
|
||||
if chain_schema_steps:
|
||||
chain_schema_steps = [schema] + chain_schema_steps
|
||||
return cs.chain_schema(chain_schema_steps)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], list[Any]]:
|
||||
"""Split `annotations` into known metadata and unknown annotations.
|
||||
|
||||
Args:
|
||||
annotations: An iterable of annotations.
|
||||
|
||||
Returns:
|
||||
A tuple contains a dict of known metadata and a list of unknown annotations.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from annotated_types import Gt, Len
|
||||
|
||||
from pydantic._internal._known_annotated_metadata import collect_known_metadata
|
||||
|
||||
print(collect_known_metadata([Gt(1), Len(42), ...]))
|
||||
#> ({'gt': 1, 'min_length': 42}, [Ellipsis])
|
||||
```
|
||||
"""
|
||||
annotations = expand_grouped_metadata(annotations)
|
||||
|
||||
res: dict[str, Any] = {}
|
||||
remaining: list[Any] = []
|
||||
|
||||
for annotation in annotations:
|
||||
# isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
|
||||
if isinstance(annotation, PydanticMetadata):
|
||||
res.update(annotation.__dict__)
|
||||
# we don't use dataclasses.asdict because that recursively calls asdict on the field values
|
||||
elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
|
||||
constraint = at_to_constraint_map[annotation_type]
|
||||
res[constraint] = getattr(annotation, constraint)
|
||||
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
|
||||
# also support PydanticMetadata classes being used without initialisation,
|
||||
# e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`
|
||||
res.update({k: v for k, v in vars(annotation).items() if not k.startswith('_')})
|
||||
else:
|
||||
remaining.append(annotation)
|
||||
# Nones can sneak in but pydantic-core will reject them
|
||||
# it'd be nice to clean things up so we don't put in None (we probably don't _need_ to, it was just easier)
|
||||
# but this is simple enough to kick that can down the road
|
||||
res = {k: v for k, v in res.items() if v is not None}
|
||||
return res, remaining
|
||||
|
||||
|
||||
def check_metadata(metadata: dict[str, Any], allowed: Iterable[str], source_type: Any) -> None:
|
||||
"""A small utility function to validate that the given metadata can be applied to the target.
|
||||
More than saving lines of code, this gives us a consistent error message for all of our internal implementations.
|
||||
|
||||
Args:
|
||||
metadata: A dict of metadata.
|
||||
allowed: An iterable of allowed metadata.
|
||||
source_type: The source type.
|
||||
|
||||
Raises:
|
||||
TypeError: If there is metadatas that can't be applied on source type.
|
||||
"""
|
||||
unknown = metadata.keys() - set(allowed)
|
||||
if unknown:
|
||||
raise TypeError(
|
||||
f'The following constraints cannot be applied to {source_type!r}: {", ".join([f"{k!r}" for k in unknown])}'
|
||||
)
|
||||
@@ -0,0 +1,194 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Mapping, TypeVar, Union
|
||||
|
||||
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..errors import PydanticErrorCodes, PydanticUserError
|
||||
from ..plugin._schema_validator import PluggableSchemaValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..dataclasses import PydanticDataclass
|
||||
from ..main import BaseModel
|
||||
|
||||
|
||||
ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class MockCoreSchema(Mapping[str, Any]):
|
||||
"""Mocker for `pydantic_core.CoreSchema` which optionally attempts to
|
||||
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
|
||||
"""
|
||||
|
||||
__slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_message: str,
|
||||
*,
|
||||
code: PydanticErrorCodes,
|
||||
attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
|
||||
) -> None:
|
||||
self._error_message = error_message
|
||||
self._code: PydanticErrorCodes = code
|
||||
self._attempt_rebuild = attempt_rebuild
|
||||
self._built_memo: CoreSchema | None = None
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._get_built().__getitem__(key)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._get_built().__len__()
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return self._get_built().__iter__()
|
||||
|
||||
def _get_built(self) -> CoreSchema:
|
||||
if self._built_memo is not None:
|
||||
return self._built_memo
|
||||
|
||||
if self._attempt_rebuild:
|
||||
schema = self._attempt_rebuild()
|
||||
if schema is not None:
|
||||
self._built_memo = schema
|
||||
return schema
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
|
||||
def rebuild(self) -> CoreSchema | None:
|
||||
self._built_memo = None
|
||||
if self._attempt_rebuild:
|
||||
val_ser = self._attempt_rebuild()
|
||||
if val_ser is not None:
|
||||
return val_ser
|
||||
else:
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
return None
|
||||
|
||||
|
||||
class MockValSer(Generic[ValSer]):
|
||||
"""Mocker for `pydantic_core.SchemaValidator` or `pydantic_core.SchemaSerializer` which optionally attempts to
|
||||
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
|
||||
"""
|
||||
|
||||
__slots__ = '_error_message', '_code', '_val_or_ser', '_attempt_rebuild'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_message: str,
|
||||
*,
|
||||
code: PydanticErrorCodes,
|
||||
val_or_ser: Literal['validator', 'serializer'],
|
||||
attempt_rebuild: Callable[[], ValSer | None] | None = None,
|
||||
) -> None:
|
||||
self._error_message = error_message
|
||||
self._val_or_ser = SchemaValidator if val_or_ser == 'validator' else SchemaSerializer
|
||||
self._code: PydanticErrorCodes = code
|
||||
self._attempt_rebuild = attempt_rebuild
|
||||
|
||||
def __getattr__(self, item: str) -> None:
|
||||
__tracebackhide__ = True
|
||||
if self._attempt_rebuild:
|
||||
val_ser = self._attempt_rebuild()
|
||||
if val_ser is not None:
|
||||
return getattr(val_ser, item)
|
||||
|
||||
# raise an AttributeError if `item` doesn't exist
|
||||
getattr(self._val_or_ser, item)
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
|
||||
def rebuild(self) -> ValSer | None:
|
||||
if self._attempt_rebuild:
|
||||
val_ser = self._attempt_rebuild()
|
||||
if val_ser is not None:
|
||||
return val_ser
|
||||
else:
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
return None
|
||||
|
||||
|
||||
def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
|
||||
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model.
|
||||
|
||||
Args:
|
||||
cls: The model class to set the mocks on
|
||||
cls_name: Name of the model class, used in error messages
|
||||
undefined_name: Name of the undefined thing, used in error messages
|
||||
"""
|
||||
undefined_type_error_message = (
|
||||
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `{cls_name}.model_rebuild()`.'
|
||||
)
|
||||
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(cls)
|
||||
else:
|
||||
return None
|
||||
|
||||
return handler
|
||||
|
||||
cls.__pydantic_core_schema__ = MockCoreSchema( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
|
||||
)
|
||||
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
|
||||
)
|
||||
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='serializer',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
|
||||
)
|
||||
|
||||
|
||||
def set_dataclass_mocks(
|
||||
cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types'
|
||||
) -> None:
|
||||
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.
|
||||
|
||||
Args:
|
||||
cls: The model class to set the mocks on
|
||||
cls_name: Name of the model class, used in error messages
|
||||
undefined_name: Name of the undefined thing, used in error messages
|
||||
"""
|
||||
from ..dataclasses import rebuild_dataclass
|
||||
|
||||
undefined_type_error_message = (
|
||||
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
|
||||
)
|
||||
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(cls)
|
||||
else:
|
||||
return None
|
||||
|
||||
return handler
|
||||
|
||||
cls.__pydantic_core_schema__ = MockCoreSchema( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
|
||||
)
|
||||
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
|
||||
)
|
||||
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
|
||||
)
|
||||
@@ -0,0 +1,752 @@
|
||||
"""Private logic for creating models."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import builtins
|
||||
import operator
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import ABCMeta
|
||||
from functools import lru_cache, partial
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Generic, Literal, NoReturn
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import PydanticUndefined, SchemaSerializer
|
||||
from typing_extensions import dataclass_transform, deprecated
|
||||
|
||||
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
|
||||
from ._config import ConfigWrapper
|
||||
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function
|
||||
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._generics import PydanticGenericMetadata, get_model_typevars_map
|
||||
from ._import_utils import import_cached_base_model, import_cached_field_info
|
||||
from ._mock_val_ser import set_model_mocks
|
||||
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
|
||||
from ._signature import generate_pydantic_signature
|
||||
from ._typing_extra import (
|
||||
_make_forward_ref,
|
||||
eval_type_backport,
|
||||
is_annotated,
|
||||
is_classvar,
|
||||
merge_cls_and_parent_ns,
|
||||
parent_frame_namespace,
|
||||
)
|
||||
from ._utils import ClassAttribute, SafeGetItemProxy
|
||||
from ._validate_call import ValidateCallWrapper
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..fields import Field as PydanticModelField
|
||||
from ..fields import FieldInfo, ModelPrivateAttr
|
||||
from ..fields import PrivateAttr as PydanticModelPrivateAttr
|
||||
from ..main import BaseModel
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
PydanticModelField = object()
|
||||
PydanticModelPrivateAttr = object()
|
||||
|
||||
object_setattr = object.__setattr__
|
||||
|
||||
|
||||
class _ModelNamespaceDict(dict):
|
||||
"""A dictionary subclass that intercepts attribute setting on model classes and
|
||||
warns about overriding of decorators.
|
||||
"""
|
||||
|
||||
def __setitem__(self, k: str, v: object) -> None:
|
||||
existing: Any = self.get(k, None)
|
||||
if existing and v is not existing and isinstance(existing, PydanticDescriptorProxy):
|
||||
warnings.warn(f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator')
|
||||
|
||||
return super().__setitem__(k, v)
|
||||
|
||||
|
||||
def NoInitField(
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
) -> Any:
|
||||
"""Only for typing purposes. Used as default value of `__pydantic_fields_set__`,
|
||||
`__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when
|
||||
synthesizing the `__init__` signature.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField))
|
||||
class ModelMetaclass(ABCMeta):
|
||||
def __new__(
|
||||
mcs,
|
||||
cls_name: str,
|
||||
bases: tuple[type[Any], ...],
|
||||
namespace: dict[str, Any],
|
||||
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
|
||||
__pydantic_reset_parent_namespace__: bool = True,
|
||||
_create_model_module: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> type:
|
||||
"""Metaclass for creating Pydantic models.
|
||||
|
||||
Args:
|
||||
cls_name: The name of the class to be created.
|
||||
bases: The base classes of the class to be created.
|
||||
namespace: The attribute dictionary of the class to be created.
|
||||
__pydantic_generic_metadata__: Metadata for generic models.
|
||||
__pydantic_reset_parent_namespace__: Reset parent namespace.
|
||||
_create_model_module: The module of the class to be created, if created by `create_model`.
|
||||
**kwargs: Catch-all for any other keyword arguments.
|
||||
|
||||
Returns:
|
||||
The new class created by the metaclass.
|
||||
"""
|
||||
# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact
|
||||
# that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
|
||||
# call we're in the middle of is for the `BaseModel` class.
|
||||
if bases:
|
||||
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)
|
||||
|
||||
config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs)
|
||||
namespace['model_config'] = config_wrapper.config_dict
|
||||
private_attributes = inspect_namespace(
|
||||
namespace, config_wrapper.ignored_types, class_vars, base_field_names
|
||||
)
|
||||
if private_attributes or base_private_attributes:
|
||||
original_model_post_init = get_model_post_init(namespace, bases)
|
||||
if original_model_post_init is not None:
|
||||
# if there are private_attributes and a model_post_init function, we handle both
|
||||
|
||||
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
|
||||
"""We need to both initialize private attributes and call the user-defined model_post_init
|
||||
method.
|
||||
"""
|
||||
init_private_attributes(self, context)
|
||||
original_model_post_init(self, context)
|
||||
|
||||
namespace['model_post_init'] = wrapped_model_post_init
|
||||
else:
|
||||
namespace['model_post_init'] = init_private_attributes
|
||||
|
||||
namespace['__class_vars__'] = class_vars
|
||||
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
|
||||
|
||||
cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore
|
||||
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
mro = cls.__mro__
|
||||
if Generic in mro and mro.index(Generic) < mro.index(BaseModel):
|
||||
warnings.warn(
|
||||
GenericBeforeBaseModelWarning(
|
||||
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
|
||||
'for pydantic generics to work properly.'
|
||||
),
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
|
||||
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'
|
||||
|
||||
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
|
||||
|
||||
# Use the getattr below to grab the __parameters__ from the `typing.Generic` parent class
|
||||
if __pydantic_generic_metadata__:
|
||||
cls.__pydantic_generic_metadata__ = __pydantic_generic_metadata__
|
||||
else:
|
||||
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
|
||||
parameters = getattr(cls, '__parameters__', None) or parent_parameters
|
||||
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
|
||||
from ..root_model import RootModelRootType
|
||||
|
||||
missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
|
||||
if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
|
||||
# This is a special case where the user has subclassed `RootModel`, but has not parametrized
|
||||
# RootModel with the generic type identifiers being used. Ex:
|
||||
# class MyModel(RootModel, Generic[T]):
|
||||
# root: T
|
||||
# Should instead just be:
|
||||
# class MyModel(RootModel[T]):
|
||||
# root: T
|
||||
parameters_str = ', '.join([x.__name__ for x in missing_parameters])
|
||||
error_message = (
|
||||
f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
|
||||
f'{parameters_str} in its parameters. '
|
||||
f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
|
||||
)
|
||||
else:
|
||||
combined_parameters = parent_parameters + missing_parameters
|
||||
parameters_str = ', '.join([str(x) for x in combined_parameters])
|
||||
generic_type_label = f'typing.Generic[{parameters_str}]'
|
||||
error_message = (
|
||||
f'All parameters must be present on typing.Generic;'
|
||||
f' you should inherit from {generic_type_label}.'
|
||||
)
|
||||
if Generic not in bases: # pragma: no cover
|
||||
# We raise an error here not because it is desirable, but because some cases are mishandled.
|
||||
# It would be nice to remove this error and still have things behave as expected, it's just
|
||||
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
|
||||
# and not returning a typing._GenericAlias from it.
|
||||
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
|
||||
error_message += (
|
||||
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
|
||||
)
|
||||
raise TypeError(error_message)
|
||||
|
||||
cls.__pydantic_generic_metadata__ = {
|
||||
'origin': None,
|
||||
'args': (),
|
||||
'parameters': parameters,
|
||||
}
|
||||
|
||||
cls.__pydantic_complete__ = False # Ensure this specific class gets completed
|
||||
|
||||
# preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
|
||||
# for attributes not in `new_namespace` (e.g. private attributes)
|
||||
for name, obj in private_attributes.items():
|
||||
obj.__set_name__(cls, name)
|
||||
|
||||
if __pydantic_reset_parent_namespace__:
|
||||
cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
|
||||
parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None)
|
||||
if isinstance(parent_namespace, dict):
|
||||
parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)
|
||||
|
||||
types_namespace = merge_cls_and_parent_ns(cls, parent_namespace)
|
||||
set_model_fields(cls, bases, config_wrapper, types_namespace)
|
||||
|
||||
if config_wrapper.frozen and '__hash__' not in namespace:
|
||||
set_default_hash_func(cls, bases)
|
||||
|
||||
complete_model_class(
|
||||
cls,
|
||||
cls_name,
|
||||
config_wrapper,
|
||||
raise_errors=False,
|
||||
types_namespace=types_namespace,
|
||||
create_model_module=_create_model_module,
|
||||
)
|
||||
|
||||
# If this is placed before the complete_model_class call above,
|
||||
# the generic computed fields return type is set to PydanticUndefined
|
||||
cls.model_computed_fields = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
|
||||
|
||||
set_deprecated_descriptors(cls)
|
||||
|
||||
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
|
||||
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
|
||||
# only hit for _proper_ subclasses of BaseModel
|
||||
super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc]
|
||||
return cls
|
||||
else:
|
||||
# These are instance variables, but have been assigned to `NoInitField` to trick the type checker.
|
||||
for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__':
|
||||
namespace.pop(
|
||||
instance_slot,
|
||||
None, # In case the metaclass is used with a class other than `BaseModel`.
|
||||
)
|
||||
namespace.get('__annotations__', {}).clear()
|
||||
return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
|
||||
|
||||
if not typing.TYPE_CHECKING: # pragma: no branch
|
||||
# We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""This is necessary to keep attribute access working for class attribute access."""
|
||||
private_attributes = self.__dict__.get('__private_attributes__')
|
||||
if private_attributes and item in private_attributes:
|
||||
return private_attributes[item]
|
||||
raise AttributeError(item)
|
||||
|
||||
@classmethod
|
||||
def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
|
||||
return _ModelNamespaceDict()
|
||||
|
||||
def __instancecheck__(self, instance: Any) -> bool:
|
||||
"""Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
|
||||
|
||||
See #3829 and python/cpython#92810
|
||||
"""
|
||||
return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)
|
||||
|
||||
@staticmethod
|
||||
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
field_names: set[str] = set()
|
||||
class_vars: set[str] = set()
|
||||
private_attributes: dict[str, ModelPrivateAttr] = {}
|
||||
for base in bases:
|
||||
if issubclass(base, BaseModel) and base is not BaseModel:
|
||||
# model_fields might not be defined yet in the case of generics, so we use getattr here:
|
||||
field_names.update(getattr(base, 'model_fields', {}).keys())
|
||||
class_vars.update(base.__class_vars__)
|
||||
private_attributes.update(base.__private_attributes__)
|
||||
return field_names, class_vars, private_attributes
|
||||
|
||||
@property
|
||||
@deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
|
||||
def __fields__(self) -> dict[str, FieldInfo]:
|
||||
warnings.warn(
|
||||
'The `__fields__` attribute is deprecated, use `model_fields` instead.', PydanticDeprecatedSince20
|
||||
)
|
||||
return self.model_fields # type: ignore
|
||||
|
||||
def __dir__(self) -> list[str]:
|
||||
attributes = list(super().__dir__())
|
||||
if '__fields__' in attributes:
|
||||
attributes.remove('__fields__')
|
||||
return attributes
|
||||
|
||||
|
||||
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
|
||||
"""This function is meant to behave like a BaseModel method to initialise private attributes.
|
||||
|
||||
It takes context as an argument since that's what pydantic-core passes when calling it.
|
||||
|
||||
Args:
|
||||
self: The BaseModel instance.
|
||||
context: The context.
|
||||
"""
|
||||
if getattr(self, '__pydantic_private__', None) is None:
|
||||
pydantic_private = {}
|
||||
for name, private_attr in self.__private_attributes__.items():
|
||||
default = private_attr.get_default()
|
||||
if default is not PydanticUndefined:
|
||||
pydantic_private[name] = default
|
||||
object_setattr(self, '__pydantic_private__', pydantic_private)
|
||||
|
||||
|
||||
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
|
||||
"""Get the `model_post_init` method from the namespace or the class bases, or `None` if not defined."""
|
||||
if 'model_post_init' in namespace:
|
||||
return namespace['model_post_init']
|
||||
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
|
||||
if model_post_init is not BaseModel.model_post_init:
|
||||
return model_post_init
|
||||
|
||||
|
||||
def inspect_namespace( # noqa C901
|
||||
namespace: dict[str, Any],
|
||||
ignored_types: tuple[type[Any], ...],
|
||||
base_class_vars: set[str],
|
||||
base_class_fields: set[str],
|
||||
) -> dict[str, ModelPrivateAttr]:
|
||||
"""Iterate over the namespace and:
|
||||
* gather private attributes
|
||||
* check for items which look like fields but are not (e.g. have no annotation) and warn.
|
||||
|
||||
Args:
|
||||
namespace: The attribute dictionary of the class to be created.
|
||||
ignored_types: A tuple of ignore types.
|
||||
base_class_vars: A set of base class class variables.
|
||||
base_class_fields: A set of base class fields.
|
||||
|
||||
Returns:
|
||||
A dict contains private attributes info.
|
||||
|
||||
Raises:
|
||||
TypeError: If there is a `__root__` field in model.
|
||||
NameError: If private attribute name is invalid.
|
||||
PydanticUserError:
|
||||
- If a field does not have a type annotation.
|
||||
- If a field on base class was overridden by a non-annotated attribute.
|
||||
"""
|
||||
from ..fields import ModelPrivateAttr, PrivateAttr
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
all_ignored_types = ignored_types + default_ignored_types()
|
||||
|
||||
private_attributes: dict[str, ModelPrivateAttr] = {}
|
||||
raw_annotations = namespace.get('__annotations__', {})
|
||||
|
||||
if '__root__' in raw_annotations or '__root__' in namespace:
|
||||
raise TypeError("To define root models, use `pydantic.RootModel` rather than a field called '__root__'")
|
||||
|
||||
ignored_names: set[str] = set()
|
||||
for var_name, value in list(namespace.items()):
|
||||
if var_name == 'model_config' or var_name == '__pydantic_extra__':
|
||||
continue
|
||||
elif (
|
||||
isinstance(value, type)
|
||||
and value.__module__ == namespace['__module__']
|
||||
and '__qualname__' in namespace
|
||||
and value.__qualname__.startswith(namespace['__qualname__'])
|
||||
):
|
||||
# `value` is a nested type defined in this namespace; don't error
|
||||
continue
|
||||
elif isinstance(value, all_ignored_types) or value.__class__.__module__ == 'functools':
|
||||
ignored_names.add(var_name)
|
||||
continue
|
||||
elif isinstance(value, ModelPrivateAttr):
|
||||
if var_name.startswith('__'):
|
||||
raise NameError(
|
||||
'Private attributes must not use dunder names;'
|
||||
f' use a single underscore prefix instead of {var_name!r}.'
|
||||
)
|
||||
elif is_valid_field_name(var_name):
|
||||
raise NameError(
|
||||
'Private attributes must not use valid field names;'
|
||||
f' use sunder names, e.g. {"_" + var_name!r} instead of {var_name!r}.'
|
||||
)
|
||||
private_attributes[var_name] = value
|
||||
del namespace[var_name]
|
||||
elif isinstance(value, FieldInfo) and not is_valid_field_name(var_name):
|
||||
suggested_name = var_name.lstrip('_') or 'my_field' # don't suggest '' for all-underscore name
|
||||
raise NameError(
|
||||
f'Fields must not use names with leading underscores;'
|
||||
f' e.g., use {suggested_name!r} instead of {var_name!r}.'
|
||||
)
|
||||
|
||||
elif var_name.startswith('__'):
|
||||
continue
|
||||
elif is_valid_privateattr_name(var_name):
|
||||
if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]):
|
||||
private_attributes[var_name] = PrivateAttr(default=value)
|
||||
del namespace[var_name]
|
||||
elif var_name in base_class_vars:
|
||||
continue
|
||||
elif var_name not in raw_annotations:
|
||||
if var_name in base_class_fields:
|
||||
raise PydanticUserError(
|
||||
f'Field {var_name!r} defined on a base class was overridden by a non-annotated attribute. '
|
||||
f'All field definitions, including overrides, require a type annotation.',
|
||||
code='model-field-overridden',
|
||||
)
|
||||
elif isinstance(value, FieldInfo):
|
||||
raise PydanticUserError(
|
||||
f'Field {var_name!r} requires a type annotation', code='model-field-missing-annotation'
|
||||
)
|
||||
else:
|
||||
raise PydanticUserError(
|
||||
f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
|
||||
f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
|
||||
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
|
||||
code='model-field-missing-annotation',
|
||||
)
|
||||
|
||||
for ann_name, ann_type in raw_annotations.items():
|
||||
if (
|
||||
is_valid_privateattr_name(ann_name)
|
||||
and ann_name not in private_attributes
|
||||
and ann_name not in ignored_names
|
||||
# This condition is a false negative when `ann_type` is stringified,
|
||||
# but it is handled in `set_model_fields`:
|
||||
and not is_classvar(ann_type)
|
||||
and ann_type not in all_ignored_types
|
||||
and getattr(ann_type, '__module__', None) != 'functools'
|
||||
):
|
||||
if isinstance(ann_type, str):
|
||||
# Walking up the frames to get the module namespace where the model is defined
|
||||
# (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`):
|
||||
frame = sys._getframe(2)
|
||||
if frame is not None:
|
||||
try:
|
||||
ann_type = eval_type_backport(
|
||||
_make_forward_ref(ann_type, is_argument=False, is_class=True),
|
||||
globalns=frame.f_globals,
|
||||
localns=frame.f_locals,
|
||||
)
|
||||
except (NameError, TypeError):
|
||||
pass
|
||||
|
||||
if is_annotated(ann_type):
|
||||
_, *metadata = typing_extensions.get_args(ann_type)
|
||||
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
|
||||
if private_attr is not None:
|
||||
private_attributes[ann_name] = private_attr
|
||||
continue
|
||||
private_attributes[ann_name] = PrivateAttr()
|
||||
|
||||
return private_attributes
|
||||
|
||||
|
||||
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
|
||||
base_hash_func = get_attribute_from_bases(bases, '__hash__')
|
||||
new_hash_func = make_hash_func(cls)
|
||||
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
|
||||
# If `__hash__` is some default, we generate a hash function.
|
||||
# It will be `None` if not overridden from BaseModel.
|
||||
# It may be `object.__hash__` if there is another
|
||||
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
|
||||
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
|
||||
# In the last case we still need a new hash function to account for new `model_fields`.
|
||||
cls.__hash__ = new_hash_func
|
||||
|
||||
|
||||
def make_hash_func(cls: type[BaseModel]) -> Any:
|
||||
getter = operator.itemgetter(*cls.model_fields.keys()) if cls.model_fields else lambda _: 0
|
||||
|
||||
def hash_func(self: Any) -> int:
|
||||
try:
|
||||
return hash(getter(self.__dict__))
|
||||
except KeyError:
|
||||
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
|
||||
# all model fields, which is how we can get here.
|
||||
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
|
||||
# and wrapping it in a `try` doesn't slow things down much in the common case.
|
||||
return hash(getter(SafeGetItemProxy(self.__dict__)))
|
||||
|
||||
return hash_func
|
||||
|
||||
|
||||
def set_model_fields(
|
||||
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
|
||||
) -> None:
|
||||
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
"""
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
|
||||
|
||||
cls.model_fields = fields
|
||||
cls.__class_vars__.update(class_vars)
|
||||
|
||||
for k in class_vars:
|
||||
# Class vars should not be private attributes
|
||||
# We remove them _here_ and not earlier because we rely on inspecting the class to determine its classvars,
|
||||
# but private attributes are determined by inspecting the namespace _prior_ to class creation.
|
||||
# In the case that a classvar with a leading-'_' is defined via a ForwardRef (e.g., when using
|
||||
# `__future__.annotations`), we want to remove the private attribute which was detected _before_ we knew it
|
||||
# evaluated to a classvar
|
||||
|
||||
value = cls.__private_attributes__.pop(k, None)
|
||||
if value is not None and value.default is not PydanticUndefined:
|
||||
setattr(cls, k, value.default)
|
||||
|
||||
|
||||
def complete_model_class(
|
||||
cls: type[BaseModel],
|
||||
cls_name: str,
|
||||
config_wrapper: ConfigWrapper,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
create_model_module: str | None = None,
|
||||
) -> bool:
|
||||
"""Finish building a model class.
|
||||
|
||||
This logic must be called after class has been created since validation functions must be bound
|
||||
and `get_type_hints` requires a class object.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
cls_name: The model or dataclass name.
|
||||
config_wrapper: The config wrapper instance.
|
||||
raise_errors: Whether to raise errors.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
create_model_module: The module of the class to be created, if created by `create_model`.
|
||||
|
||||
Returns:
|
||||
`True` if the model is successfully completed, else `False`.
|
||||
|
||||
Raises:
|
||||
PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in`__get_pydantic_core_schema__`
|
||||
and `raise_errors=True`.
|
||||
"""
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
types_namespace,
|
||||
typevars_map,
|
||||
)
|
||||
|
||||
handler = CallbackGetCoreSchemaHandler(
|
||||
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
|
||||
gen_schema,
|
||||
ref_mode='unpack',
|
||||
)
|
||||
|
||||
if config_wrapper.defer_build and 'model' in config_wrapper.experimental_defer_build_mode:
|
||||
set_model_mocks(cls, cls_name)
|
||||
return False
|
||||
|
||||
try:
|
||||
schema = cls.__get_pydantic_core_schema__(cls, handler)
|
||||
except PydanticUndefinedAnnotation as e:
|
||||
if raise_errors:
|
||||
raise
|
||||
set_model_mocks(cls, cls_name, f'`{e.name}`')
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(cls)
|
||||
|
||||
try:
|
||||
schema = gen_schema.clean_schema(schema)
|
||||
except gen_schema.CollectedInvalid:
|
||||
set_model_mocks(cls, cls_name)
|
||||
return False
|
||||
|
||||
# debug(schema)
|
||||
cls.__pydantic_core_schema__ = schema
|
||||
|
||||
cls.__pydantic_validator__ = create_schema_validator(
|
||||
schema,
|
||||
cls,
|
||||
create_model_module or cls.__module__,
|
||||
cls.__qualname__,
|
||||
'create_model' if create_model_module else 'BaseModel',
|
||||
core_config,
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
|
||||
cls.__pydantic_complete__ = True
|
||||
|
||||
# set __signature__ attr only for model class, but not for its instances
|
||||
cls.__signature__ = ClassAttribute(
|
||||
'__signature__',
|
||||
generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def set_deprecated_descriptors(cls: type[BaseModel]) -> None:
|
||||
"""Set data descriptors on the class for deprecated fields."""
|
||||
for field, field_info in cls.model_fields.items():
|
||||
if (msg := field_info.deprecation_message) is not None:
|
||||
desc = _DeprecatedFieldDescriptor(msg)
|
||||
desc.__set_name__(cls, field)
|
||||
setattr(cls, field, desc)
|
||||
|
||||
for field, computed_field_info in cls.model_computed_fields.items():
|
||||
if (
|
||||
(msg := computed_field_info.deprecation_message) is not None
|
||||
# Avoid having two warnings emitted:
|
||||
and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__')
|
||||
):
|
||||
desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property)
|
||||
desc.__set_name__(cls, field)
|
||||
setattr(cls, field, desc)
|
||||
|
||||
|
||||
class _DeprecatedFieldDescriptor:
|
||||
"""Data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
|
||||
|
||||
Attributes:
|
||||
msg: The deprecation message to be emitted.
|
||||
wrapped_property: The property instance if the deprecated field is a computed field, or `None`.
|
||||
field_name: The name of the field being deprecated.
|
||||
"""
|
||||
|
||||
field_name: str
|
||||
|
||||
def __init__(self, msg: str, wrapped_property: property | None = None) -> None:
|
||||
self.msg = msg
|
||||
self.wrapped_property = wrapped_property
|
||||
|
||||
def __set_name__(self, cls: type[BaseModel], name: str) -> None:
|
||||
self.field_name = name
|
||||
|
||||
def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any:
|
||||
if obj is None:
|
||||
raise AttributeError(self.field_name)
|
||||
|
||||
warnings.warn(self.msg, builtins.DeprecationWarning, stacklevel=2)
|
||||
|
||||
if self.wrapped_property is not None:
|
||||
return self.wrapped_property.__get__(obj, obj_type)
|
||||
return obj.__dict__[self.field_name]
|
||||
|
||||
# Defined to take precedence over the instance's dictionary
|
||||
# Note that it will not be called when setting a value on a model instance
|
||||
# as `BaseModel.__setattr__` is defined and takes priority.
|
||||
def __set__(self, obj: Any, value: Any) -> NoReturn:
|
||||
raise AttributeError(self.field_name)
|
||||
|
||||
|
||||
class _PydanticWeakRef:
|
||||
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
|
||||
|
||||
Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
|
||||
to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
|
||||
`weakref.ref` instead of subclassing it.
|
||||
|
||||
See https://github.com/pydantic/pydantic/issues/6763 for context.
|
||||
|
||||
Semantics:
|
||||
- If not pickled, behaves the same as a `weakref.ref`.
|
||||
- If pickled along with the referenced object, the same `weakref.ref` behavior
|
||||
will be maintained between them after unpickling.
|
||||
- If pickled without the referenced object, after unpickling the underlying
|
||||
reference will be cleared (`__call__` will always return `None`).
|
||||
"""
|
||||
|
||||
def __init__(self, obj: Any):
|
||||
if obj is None:
|
||||
# The object will be `None` upon deserialization if the serialized weakref
|
||||
# had lost its underlying object.
|
||||
self._wr = None
|
||||
else:
|
||||
self._wr = weakref.ref(obj)
|
||||
|
||||
def __call__(self) -> Any:
|
||||
if self._wr is None:
|
||||
return None
|
||||
else:
|
||||
return self._wr()
|
||||
|
||||
def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
|
||||
return _PydanticWeakRef, (self(),)
|
||||
|
||||
|
||||
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Takes an input dictionary, and produces a new value that (invertibly) replaces the values with weakrefs.
|
||||
|
||||
We can't just use a WeakValueDictionary because many types (including int, str, etc.) can't be stored as values
|
||||
in a WeakValueDictionary.
|
||||
|
||||
The `unpack_lenient_weakvaluedict` function can be used to reverse this operation.
|
||||
"""
|
||||
if d is None:
|
||||
return None
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
try:
|
||||
proxy = _PydanticWeakRef(v)
|
||||
except TypeError:
|
||||
proxy = v
|
||||
result[k] = proxy
|
||||
return result
|
||||
|
||||
|
||||
def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Inverts the transform performed by `build_lenient_weakvaluedict`."""
|
||||
if d is None:
|
||||
return None
|
||||
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, _PydanticWeakRef):
|
||||
v = v()
|
||||
if v is not None:
|
||||
result[k] = v
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def default_ignored_types() -> tuple[type[Any], ...]:
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
return (
|
||||
FunctionType,
|
||||
property,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
PydanticDescriptorProxy,
|
||||
ComputedFieldInfo,
|
||||
ValidateCallWrapper,
|
||||
)
|
||||
118
venv/lib/python3.11/site-packages/pydantic/_internal/_repr.py
Normal file
118
venv/lib/python3.11/site-packages/pydantic/_internal/_repr.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Tools to provide pretty/human-readable display of objects."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import types
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
import typing_extensions
|
||||
|
||||
from . import _typing_extra
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]'
|
||||
RichReprResult: typing_extensions.TypeAlias = (
|
||||
'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]'
|
||||
)
|
||||
|
||||
|
||||
class PlainRepr(str):
|
||||
"""String class where repr doesn't include quotes. Useful with Representation when you want to return a string
|
||||
representation of something that is valid (or pseudo-valid) python.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
|
||||
class Representation:
|
||||
# Mixin to provide `__str__`, `__repr__`, and `__pretty__` and `__rich_repr__` methods.
|
||||
# `__pretty__` is used by [devtools](https://python-devtools.helpmanual.io/).
|
||||
# `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html).
|
||||
# (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
|
||||
|
||||
# we don't want to use a type annotation here as it can break get_type_hints
|
||||
__slots__ = () # type: typing.Collection[str]
|
||||
|
||||
def __repr_args__(self) -> ReprArgs:
|
||||
"""Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
|
||||
|
||||
Can either return:
|
||||
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
|
||||
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
|
||||
"""
|
||||
attrs_names = self.__slots__
|
||||
if not attrs_names and hasattr(self, '__dict__'):
|
||||
attrs_names = self.__dict__.keys()
|
||||
attrs = ((s, getattr(self, s)) for s in attrs_names)
|
||||
return [(a, v) for a, v in attrs if v is not None]
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
"""Name of the instance's class, used in __repr__."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
|
||||
|
||||
def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]:
|
||||
"""Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects."""
|
||||
yield self.__repr_name__() + '('
|
||||
yield 1
|
||||
for name, value in self.__repr_args__():
|
||||
if name is not None:
|
||||
yield name + '='
|
||||
yield fmt(value)
|
||||
yield ','
|
||||
yield 0
|
||||
yield -1
|
||||
yield ')'
|
||||
|
||||
def __rich_repr__(self) -> RichReprResult:
|
||||
"""Used by Rich (https://rich.readthedocs.io/en/stable/pretty.html) to pretty print objects."""
|
||||
for name, field_repr in self.__repr_args__():
|
||||
if name is None:
|
||||
yield field_repr
|
||||
else:
|
||||
yield name, field_repr
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr_str__(' ')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
|
||||
|
||||
|
||||
def display_as_type(obj: Any) -> str:
|
||||
"""Pretty representation of a type, should be as close as possible to the original type definition string.
|
||||
|
||||
Takes some logic from `typing._type_repr`.
|
||||
"""
|
||||
if isinstance(obj, types.FunctionType):
|
||||
return obj.__name__
|
||||
elif obj is ...:
|
||||
return '...'
|
||||
elif isinstance(obj, Representation):
|
||||
return repr(obj)
|
||||
elif isinstance(obj, typing_extensions.TypeAliasType):
|
||||
return str(obj)
|
||||
|
||||
if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
|
||||
obj = obj.__class__
|
||||
|
||||
if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
|
||||
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
|
||||
return f'Union[{args}]'
|
||||
elif isinstance(obj, _typing_extra.WithArgsTypes):
|
||||
if typing_extensions.get_origin(obj) == typing_extensions.Literal:
|
||||
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
|
||||
else:
|
||||
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
|
||||
try:
|
||||
return f'{obj.__qualname__}[{args}]'
|
||||
except AttributeError:
|
||||
return str(obj) # handles TypeAliasType in 3.12
|
||||
elif isinstance(obj, type):
|
||||
return obj.__qualname__
|
||||
else:
|
||||
return repr(obj).replace('typing.', '').replace('typing_extensions.', '')
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Types and utility functions used by various other internal tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
|
||||
from ._core_utils import CoreSchemaOrField
|
||||
from ._generate_schema import GenerateSchema
|
||||
|
||||
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
|
||||
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
|
||||
|
||||
|
||||
class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
|
||||
"""JsonSchemaHandler implementation that doesn't do ref unwrapping by default.
|
||||
|
||||
This is used for any Annotated metadata so that we don't end up with conflicting
|
||||
modifications to the definition schema.
|
||||
|
||||
Used internally by Pydantic, please do not rely on this implementation.
|
||||
See `GetJsonSchemaHandler` for the handler API.
|
||||
"""
|
||||
|
||||
def __init__(self, generate_json_schema: GenerateJsonSchema, handler_override: HandlerOverride | None) -> None:
|
||||
self.generate_json_schema = generate_json_schema
|
||||
self.handler = handler_override or generate_json_schema.generate_inner
|
||||
self.mode = generate_json_schema.mode
|
||||
|
||||
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
|
||||
return self.handler(core_schema)
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
|
||||
"""Resolves `$ref` in the json schema.
|
||||
|
||||
This returns the input json schema if there is no `$ref` in json schema.
|
||||
|
||||
Args:
|
||||
maybe_ref_json_schema: The input json schema that may contains `$ref`.
|
||||
|
||||
Returns:
|
||||
Resolved json schema.
|
||||
|
||||
Raises:
|
||||
LookupError: If it can't find the definition for `$ref`.
|
||||
"""
|
||||
if '$ref' not in maybe_ref_json_schema:
|
||||
return maybe_ref_json_schema
|
||||
ref = maybe_ref_json_schema['$ref']
|
||||
json_schema = self.generate_json_schema.get_schema_from_definitions(ref)
|
||||
if json_schema is None:
|
||||
raise LookupError(
|
||||
f'Could not find a ref for {ref}.'
|
||||
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
|
||||
)
|
||||
return json_schema
|
||||
|
||||
|
||||
class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
|
||||
"""Wrapper to use an arbitrary function as a `GetCoreSchemaHandler`.
|
||||
|
||||
Used internally by Pydantic, please do not rely on this implementation.
|
||||
See `GetCoreSchemaHandler` for the handler API.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler: Callable[[Any], core_schema.CoreSchema],
|
||||
generate_schema: GenerateSchema,
|
||||
ref_mode: Literal['to-def', 'unpack'] = 'to-def',
|
||||
) -> None:
|
||||
self._handler = handler
|
||||
self._generate_schema = generate_schema
|
||||
self._ref_mode = ref_mode
|
||||
|
||||
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
schema = self._handler(source_type)
|
||||
ref = schema.get('ref')
|
||||
if self._ref_mode == 'to-def':
|
||||
if ref is not None:
|
||||
self._generate_schema.defs.definitions[ref] = schema
|
||||
return core_schema.definition_reference_schema(ref)
|
||||
return schema
|
||||
else: # ref_mode = 'unpack
|
||||
return self.resolve_ref_schema(schema)
|
||||
|
||||
def _get_types_namespace(self) -> dict[str, Any] | None:
|
||||
return self._generate_schema._types_namespace
|
||||
|
||||
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
return self._generate_schema.generate_schema(source_type)
|
||||
|
||||
@property
|
||||
def field_name(self) -> str | None:
|
||||
return self._generate_schema.field_name_stack.get()
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
"""Resolves reference in the core schema.
|
||||
|
||||
Args:
|
||||
maybe_ref_schema: The input core schema that may contains reference.
|
||||
|
||||
Returns:
|
||||
Resolved core schema.
|
||||
|
||||
Raises:
|
||||
LookupError: If it can't find the definition for reference.
|
||||
"""
|
||||
if maybe_ref_schema['type'] == 'definition-ref':
|
||||
ref = maybe_ref_schema['schema_ref']
|
||||
if ref not in self._generate_schema.defs.definitions:
|
||||
raise LookupError(
|
||||
f'Could not find a ref for {ref}.'
|
||||
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
|
||||
)
|
||||
return self._generate_schema.defs.definitions[ref]
|
||||
elif maybe_ref_schema['type'] == 'definitions':
|
||||
return self.resolve_ref_schema(maybe_ref_schema['schema'])
|
||||
return maybe_ref_schema
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticOmit, core_schema
|
||||
|
||||
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.Deque: collections.deque,
|
||||
collections.deque: collections.deque,
|
||||
list: list,
|
||||
typing.List: list,
|
||||
set: set,
|
||||
typing.AbstractSet: set,
|
||||
typing.Set: set,
|
||||
frozenset: frozenset,
|
||||
typing.FrozenSet: frozenset,
|
||||
typing.Sequence: list,
|
||||
typing.MutableSequence: list,
|
||||
typing.MutableSet: set,
|
||||
# this doesn't handle subclasses of these
|
||||
# parametrized typing.Set creates one of these
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
}
|
||||
|
||||
|
||||
def serialize_sequence_via_list(
|
||||
v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
|
||||
) -> Any:
|
||||
items: list[Any] = []
|
||||
|
||||
mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
|
||||
if mapped_origin is None:
|
||||
# we shouldn't hit this branch, should probably add a serialization error or something
|
||||
return v
|
||||
|
||||
for index, item in enumerate(v):
|
||||
try:
|
||||
v = handler(item, index)
|
||||
except PydanticOmit:
|
||||
pass
|
||||
else:
|
||||
items.append(v)
|
||||
|
||||
if info.mode_is_json():
|
||||
return items
|
||||
else:
|
||||
return mapped_origin(items)
|
||||
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from inspect import Parameter, Signature, signature
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from ._config import ConfigWrapper
|
||||
from ._utils import is_valid_identifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..fields import FieldInfo
|
||||
|
||||
|
||||
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
|
||||
"""Extract the correct name to use for the field when generating a signature.
|
||||
|
||||
Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
|
||||
First priority is given to the alias, then the validation_alias, then the field name.
|
||||
|
||||
Args:
|
||||
field_name: The name of the field
|
||||
field_info: The corresponding FieldInfo object.
|
||||
|
||||
Returns:
|
||||
The correct name to use when generating a signature.
|
||||
"""
|
||||
if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
|
||||
return field_info.alias
|
||||
if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
|
||||
return field_info.validation_alias
|
||||
|
||||
return field_name
|
||||
|
||||
|
||||
def _process_param_defaults(param: Parameter) -> Parameter:
|
||||
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
|
||||
|
||||
Args:
|
||||
param (Parameter): The parameter
|
||||
|
||||
Returns:
|
||||
Parameter: The custom processed parameter
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
|
||||
param_default = param.default
|
||||
if isinstance(param_default, FieldInfo):
|
||||
annotation = param.annotation
|
||||
# Replace the annotation if appropriate
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if annotation == 'Any':
|
||||
annotation = Any
|
||||
|
||||
# Replace the field default
|
||||
default = param_default.default
|
||||
if default is PydanticUndefined:
|
||||
if param_default.default_factory is PydanticUndefined:
|
||||
default = Signature.empty
|
||||
else:
|
||||
# this is used by dataclasses to indicate a factory exists:
|
||||
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
|
||||
return param.replace(
|
||||
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
|
||||
)
|
||||
return param
|
||||
|
||||
|
||||
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
|
||||
init: Callable[..., None],
|
||||
fields: dict[str, FieldInfo],
|
||||
config_wrapper: ConfigWrapper,
|
||||
) -> dict[str, Parameter]:
|
||||
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
|
||||
from itertools import islice
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if fields.get(param.name):
|
||||
# exclude params with init=False
|
||||
if getattr(fields[param.name], 'init', True) is False:
|
||||
continue
|
||||
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
|
||||
if param.annotation == 'Any':
|
||||
param = param.replace(annotation=Any)
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = config_wrapper.populate_by_name
|
||||
for field_name, field in fields.items():
|
||||
# when alias is a str it should be used for signature generation
|
||||
param_name = _field_name_for_signature(field_name, field)
|
||||
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
|
||||
if not is_valid_identifier(param_name):
|
||||
if allow_names:
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
|
||||
kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
|
||||
)
|
||||
|
||||
if config_wrapper.extra == 'allow':
|
||||
use_var_kw = True
|
||||
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('self', Parameter.POSITIONAL_ONLY),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
else:
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
|
||||
return merged_params
|
||||
|
||||
|
||||
def generate_pydantic_signature(
|
||||
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper, is_dataclass: bool = False
|
||||
) -> Signature:
|
||||
"""Generate signature for a pydantic BaseModel or dataclass.
|
||||
|
||||
Args:
|
||||
init: The class init.
|
||||
fields: The model fields.
|
||||
config_wrapper: The config wrapper instance.
|
||||
is_dataclass: Whether the model is a dataclass.
|
||||
|
||||
Returns:
|
||||
The dataclass/BaseModel subclass signature.
|
||||
"""
|
||||
merged_params = _generate_signature_parameters(init, fields, config_wrapper)
|
||||
|
||||
if is_dataclass:
|
||||
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
|
||||
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||||
@@ -0,0 +1,403 @@
|
||||
"""Logic for generating pydantic-core schemas for standard library types.
|
||||
|
||||
Import of this module is deferred since it contains imports of many standard library modules.
|
||||
"""
|
||||
|
||||
# TODO: eventually, we'd like to move all of the types handled here to have pydantic-core validators
|
||||
# so that we can avoid this annotation injection and just use the standard pydantic-core schema generation
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import dataclasses
|
||||
import os
|
||||
import typing
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Iterable, Tuple, TypeVar
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
PydanticCustomError,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from pydantic._internal._serializers import serialize_sequence_via_list
|
||||
from pydantic.errors import PydanticSchemaGenerationError
|
||||
from pydantic.types import Strict
|
||||
|
||||
from ..json_schema import JsonSchemaValue
|
||||
from . import _known_annotated_metadata, _typing_extra
|
||||
from ._import_utils import import_cached_field_info
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._generate_schema import GenerateSchema
|
||||
|
||||
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class InnerSchemaValidator:
|
||||
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
|
||||
|
||||
core_schema: CoreSchema
|
||||
js_schema: JsonSchemaValue | None = None
|
||||
js_core_schema: CoreSchema | None = None
|
||||
js_schema_update: JsonSchemaValue | None = None
|
||||
|
||||
def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
if self.js_schema is not None:
|
||||
return self.js_schema
|
||||
js_schema = handler(self.js_core_schema or self.core_schema)
|
||||
if self.js_schema_update is not None:
|
||||
js_schema.update(self.js_schema_update)
|
||||
return js_schema
|
||||
|
||||
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
return self.core_schema
|
||||
|
||||
|
||||
def path_schema_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any]
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
import pathlib
|
||||
|
||||
orig_source_type: Any = get_origin(source_type) or source_type
|
||||
if (
|
||||
(source_type_args := get_args(source_type))
|
||||
and orig_source_type is os.PathLike
|
||||
and source_type_args[0] not in {str, bytes, Any}
|
||||
):
|
||||
return None
|
||||
|
||||
if orig_source_type not in {
|
||||
os.PathLike,
|
||||
pathlib.Path,
|
||||
pathlib.PurePath,
|
||||
pathlib.PosixPath,
|
||||
pathlib.PurePosixPath,
|
||||
pathlib.PureWindowsPath,
|
||||
}:
|
||||
return None
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, orig_source_type)
|
||||
|
||||
is_first_arg_byte = source_type_args and source_type_args[0] is bytes
|
||||
construct_path = pathlib.PurePath if orig_source_type is os.PathLike else orig_source_type
|
||||
constrained_schema = (
|
||||
core_schema.bytes_schema(**metadata) if is_first_arg_byte else core_schema.str_schema(**metadata)
|
||||
)
|
||||
|
||||
def path_validator(input_value: str | bytes) -> os.PathLike[Any]: # type: ignore
|
||||
try:
|
||||
if is_first_arg_byte:
|
||||
if isinstance(input_value, bytes):
|
||||
try:
|
||||
input_value = input_value.decode()
|
||||
except UnicodeDecodeError as e:
|
||||
raise PydanticCustomError('bytes_type', 'Input must be valid bytes') from e
|
||||
else:
|
||||
raise PydanticCustomError('bytes_type', 'Input must be bytes')
|
||||
elif not isinstance(input_value, str):
|
||||
raise PydanticCustomError('path_type', 'Input is not a valid path')
|
||||
|
||||
return construct_path(input_value)
|
||||
except TypeError as e:
|
||||
raise PydanticCustomError('path_type', 'Input is not a valid path') from e
|
||||
|
||||
instance_schema = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_schema),
|
||||
python_schema=core_schema.is_instance_schema(orig_source_type),
|
||||
)
|
||||
|
||||
strict: bool | None = None
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, Strict):
|
||||
strict = annotation.strict
|
||||
|
||||
schema = core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.union_schema(
|
||||
[
|
||||
instance_schema,
|
||||
core_schema.no_info_after_validator_function(path_validator, constrained_schema),
|
||||
],
|
||||
custom_error_type='path_type',
|
||||
custom_error_message=f'Input is not a valid path for {orig_source_type}',
|
||||
strict=True,
|
||||
),
|
||||
strict_schema=instance_schema,
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
return (
|
||||
orig_source_type,
|
||||
[
|
||||
InnerSchemaValidator(schema, js_core_schema=constrained_schema, js_schema_update={'format': 'path'}),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def deque_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
|
||||
) -> collections.deque[Any]:
|
||||
if isinstance(input_value, collections.deque):
|
||||
maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
|
||||
if maxlens:
|
||||
maxlen = min(maxlens)
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
else:
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class DequeValidator:
|
||||
item_source_type: type[Any]
|
||||
metadata: dict[str, Any]
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.item_source_type is Any:
|
||||
items_schema = None
|
||||
else:
|
||||
items_schema = handler.generate_schema(self.item_source_type)
|
||||
|
||||
# if we have a MaxLen annotation might as well set that as the default maxlen on the deque
|
||||
# this lets us re-use existing metadata annotations to let users set the maxlen on a dequeue
|
||||
# that e.g. comes from JSON
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(deque_validator, maxlen=self.metadata.get('max_length', None)),
|
||||
)
|
||||
|
||||
# we have to use a lax list schema here, because we need to validate the deque's
|
||||
# items via a list schema, but it's ok if the deque itself is not a list
|
||||
metadata_with_strict_override = {**self.metadata, 'strict': False}
|
||||
constrained_schema = core_schema.list_schema(items_schema, **metadata_with_strict_override)
|
||||
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.list_schema(),
|
||||
python_schema=core_schema.is_instance_schema(collections.deque),
|
||||
)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if self.metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def deque_schema_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any]
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = typing.cast(Tuple[Any], (Any,))
|
||||
elif len(args) != 1:
|
||||
raise ValueError('Expected deque to have exactly 1 generic parameter')
|
||||
|
||||
item_source_type = args[0]
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (source_type, [DequeValidator(item_source_type, metadata), *remaining_annotations])
|
||||
|
||||
|
||||
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.DefaultDict: collections.defaultdict,
|
||||
collections.defaultdict: collections.defaultdict,
|
||||
collections.OrderedDict: collections.OrderedDict,
|
||||
typing_extensions.OrderedDict: collections.OrderedDict,
|
||||
dict: dict,
|
||||
typing.Dict: dict,
|
||||
collections.Counter: collections.Counter,
|
||||
typing.Counter: collections.Counter,
|
||||
# this doesn't handle subclasses of these
|
||||
typing.Mapping: dict,
|
||||
typing.MutableMapping: dict,
|
||||
# parametrized typing.{Mutable}Mapping creates one of these
|
||||
collections.abc.MutableMapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
}
|
||||
|
||||
|
||||
def defaultdict_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
|
||||
) -> collections.defaultdict[Any, Any]:
|
||||
if isinstance(input_value, collections.defaultdict):
|
||||
default_factory = input_value.default_factory
|
||||
return collections.defaultdict(default_factory, handler(input_value))
|
||||
else:
|
||||
return collections.defaultdict(default_default_factory, handler(input_value))
|
||||
|
||||
|
||||
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
|
||||
def infer_default() -> Callable[[], Any]:
|
||||
allowed_default_types: dict[Any, Any] = {
|
||||
typing.Tuple: tuple,
|
||||
tuple: tuple,
|
||||
collections.abc.Sequence: tuple,
|
||||
collections.abc.MutableSequence: list,
|
||||
typing.List: list,
|
||||
list: list,
|
||||
typing.Sequence: list,
|
||||
typing.Set: set,
|
||||
set: set,
|
||||
typing.MutableSet: set,
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
typing.MutableMapping: dict,
|
||||
typing.Mapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
collections.abc.MutableMapping: dict,
|
||||
float: float,
|
||||
int: int,
|
||||
str: str,
|
||||
bool: bool,
|
||||
}
|
||||
values_type_origin = get_origin(values_source_type) or values_source_type
|
||||
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
|
||||
if isinstance(values_type_origin, TypeVar):
|
||||
|
||||
def type_var_default_factory() -> None:
|
||||
raise RuntimeError(
|
||||
'Generic defaultdict cannot be used without a concrete value type or an'
|
||||
' explicit default factory, ' + instructions
|
||||
)
|
||||
|
||||
return type_var_default_factory
|
||||
elif values_type_origin not in allowed_default_types:
|
||||
# a somewhat subjective set of types that have reasonable default values
|
||||
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
|
||||
raise PydanticSchemaGenerationError(
|
||||
f'Unable to infer a default factory for keys of type {values_source_type}.'
|
||||
f' Only {allowed_msg} are supported, other types require an explicit default factory'
|
||||
' ' + instructions
|
||||
)
|
||||
return allowed_default_types[values_type_origin]
|
||||
|
||||
# Assume Annotated[..., Field(...)]
|
||||
if _typing_extra.is_annotated(values_source_type):
|
||||
field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
|
||||
else:
|
||||
field_info = None
|
||||
if field_info and field_info.default_factory:
|
||||
default_default_factory = field_info.default_factory
|
||||
else:
|
||||
default_default_factory = infer_default()
|
||||
return default_default_factory
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class MappingValidator:
|
||||
mapped_origin: type[Any]
|
||||
keys_source_type: type[Any]
|
||||
values_source_type: type[Any]
|
||||
min_length: int | None = None
|
||||
max_length: int | None = None
|
||||
strict: bool = False
|
||||
|
||||
def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
|
||||
return handler(v)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.keys_source_type is Any:
|
||||
keys_schema = None
|
||||
else:
|
||||
keys_schema = handler.generate_schema(self.keys_source_type)
|
||||
if self.values_source_type is Any:
|
||||
values_schema = None
|
||||
else:
|
||||
values_schema = handler.generate_schema(self.values_source_type)
|
||||
|
||||
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
|
||||
|
||||
if self.mapped_origin is dict:
|
||||
schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
else:
|
||||
constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.dict_schema(),
|
||||
python_schema=core_schema.is_instance_schema(self.mapped_origin),
|
||||
)
|
||||
|
||||
if self.mapped_origin is collections.defaultdict:
|
||||
default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(defaultdict_validator, default_default_factory=default_default_factory),
|
||||
)
|
||||
else:
|
||||
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
self.serialize_mapping_via_dict,
|
||||
schema=core_schema.dict_schema(
|
||||
keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
|
||||
),
|
||||
info_arg=False,
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def mapping_like_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any]
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
origin: Any = get_origin(source_type)
|
||||
|
||||
mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
|
||||
if mapped_origin is None:
|
||||
return None
|
||||
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = typing.cast(Tuple[Any, Any], (Any, Any))
|
||||
elif mapped_origin is collections.Counter:
|
||||
# a single generic
|
||||
if len(args) != 1:
|
||||
raise ValueError('Expected Counter to have exactly 1 generic parameter')
|
||||
args = (args[0], int) # keys are always an int
|
||||
elif len(args) != 2:
|
||||
raise ValueError('Expected mapping to have exactly 2 generic parameters')
|
||||
|
||||
keys_source_type, values_source_type = args
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (
|
||||
source_type,
|
||||
[
|
||||
MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,581 @@
|
||||
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from types import GetSetDescriptorType
|
||||
from typing import TYPE_CHECKING, Any, Final, Iterable
|
||||
|
||||
from typing_extensions import Annotated, Literal, TypeAliasType, TypeGuard, deprecated, get_args, get_origin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._dataclasses import StandardDataclass
|
||||
|
||||
try:
|
||||
from typing import _TypingBase # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
from typing import _Final as _TypingBase # type: ignore[attr-defined]
|
||||
|
||||
typing_base = _TypingBase
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
TypingGenericAlias = ()
|
||||
else:
|
||||
from typing import GenericAlias as TypingGenericAlias # type: ignore
|
||||
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import NotRequired, Required
|
||||
else:
|
||||
from typing import NotRequired, Required # noqa: F401
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union
|
||||
|
||||
WithArgsTypes = (TypingGenericAlias,)
|
||||
|
||||
else:
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union or tp is types.UnionType
|
||||
|
||||
WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined]
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
NoneType = type(None)
|
||||
EllipsisType = type(Ellipsis)
|
||||
else:
|
||||
from types import NoneType as NoneType
|
||||
|
||||
|
||||
LITERAL_TYPES: set[Any] = {Literal}
|
||||
if hasattr(typing, 'Literal'):
|
||||
LITERAL_TYPES.add(typing.Literal) # type: ignore
|
||||
|
||||
# Check if `deprecated` is a type to prevent errors when using typing_extensions < 4.9.0
|
||||
DEPRECATED_TYPES: tuple[Any, ...] = (deprecated,) if isinstance(deprecated, type) else ()
|
||||
if hasattr(warnings, 'deprecated'):
|
||||
DEPRECATED_TYPES = (*DEPRECATED_TYPES, warnings.deprecated) # type: ignore
|
||||
|
||||
NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))
|
||||
|
||||
|
||||
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
|
||||
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
return type_ in NONE_TYPES
|
||||
|
||||
|
||||
def is_callable_type(type_: type[Any]) -> bool:
|
||||
return type_ is Callable or get_origin(type_) is Callable
|
||||
|
||||
|
||||
def is_literal_type(type_: type[Any]) -> bool:
|
||||
return Literal is not None and get_origin(type_) in LITERAL_TYPES
|
||||
|
||||
|
||||
def is_deprecated_instance(instance: Any) -> TypeGuard[deprecated]:
|
||||
return isinstance(instance, DEPRECATED_TYPES)
|
||||
|
||||
|
||||
def literal_values(type_: type[Any]) -> tuple[Any, ...]:
|
||||
return get_args(type_)
|
||||
|
||||
|
||||
def all_literal_values(type_: type[Any]) -> list[Any]:
|
||||
"""This method is used to retrieve all Literal values as
|
||||
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
|
||||
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`.
|
||||
"""
|
||||
if not is_literal_type(type_):
|
||||
return [type_]
|
||||
|
||||
values = literal_values(type_)
|
||||
return [x for value in values for x in all_literal_values(value)]
|
||||
|
||||
|
||||
def is_annotated(ann_type: Any) -> bool:
|
||||
return get_origin(ann_type) is Annotated
|
||||
|
||||
|
||||
def annotated_type(type_: Any) -> Any | None:
|
||||
return get_args(type_)[0] if is_annotated(type_) else None
|
||||
|
||||
|
||||
def is_namedtuple(type_: type[Any]) -> bool:
|
||||
"""Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`.
|
||||
"""
|
||||
from ._utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
|
||||
|
||||
|
||||
test_new_type = typing.NewType('test_new_type', str)
|
||||
|
||||
|
||||
def is_new_type(type_: type[Any]) -> bool:
|
||||
"""Check whether type_ was created using typing.NewType.
|
||||
|
||||
Can't use isinstance because it fails <3.10.
|
||||
"""
|
||||
return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type]
|
||||
|
||||
|
||||
classvar_re = re.compile(r'(\w+\.)?ClassVar\[')
|
||||
|
||||
|
||||
def _check_classvar(v: type[Any] | None) -> bool:
|
||||
return v is not None and v.__class__ is typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
|
||||
|
||||
|
||||
def is_classvar(ann_type: type[Any]) -> bool:
|
||||
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
|
||||
return True
|
||||
|
||||
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
|
||||
# forward references, see #3679
|
||||
if ann_type.__class__ == typing.ForwardRef and classvar_re.match(ann_type.__forward_arg__):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _check_finalvar(v: type[Any] | None) -> bool:
|
||||
"""Check if a given type is a `typing.Final` type."""
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
|
||||
|
||||
|
||||
def is_finalvar(ann_type: Any) -> bool:
|
||||
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
|
||||
|
||||
|
||||
def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
|
||||
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
|
||||
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
|
||||
and suggestion at the end of the next comment by @gvanrossum.
|
||||
|
||||
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
|
||||
parent of where it is called.
|
||||
|
||||
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
|
||||
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
|
||||
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
|
||||
|
||||
There are some cases where we want to force fetching the parent namespace, ex: during a `model_rebuild` call.
|
||||
In this case, we want both the namespace of the class' module, if applicable, and the parent namespace of the
|
||||
module where the rebuild is called.
|
||||
|
||||
In other cases, like during initial schema build, if a class is defined at the top module level, we don't need to
|
||||
fetch that module's namespace, because the class' __module__ attribute can be used to access the parent namespace.
|
||||
This is done in `_typing_extra.get_module_ns_of`. Thus, there's no need to cache the parent frame namespace in this case.
|
||||
"""
|
||||
frame = sys._getframe(parent_depth)
|
||||
|
||||
# note, we don't copy frame.f_locals here (or during the last return call), because we don't expect the namespace to be modified down the line
|
||||
# if this becomes a problem, we could implement some sort of frozen mapping structure to enforce this
|
||||
if force:
|
||||
return frame.f_locals
|
||||
|
||||
# if either of the following conditions are true, the class is defined at the top module level
|
||||
# to better understand why we need both of these checks, see
|
||||
# https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531
|
||||
if frame.f_back is None or frame.f_code.co_name == '<module>':
|
||||
return None
|
||||
|
||||
return frame.f_locals
|
||||
|
||||
|
||||
def get_module_ns_of(obj: Any) -> dict[str, Any]:
|
||||
"""Get the namespace of the module where the object is defined.
|
||||
|
||||
Caution: this function does not return a copy of the module namespace, so it should not be mutated.
|
||||
The burden of enforcing this is on the caller.
|
||||
"""
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
if module_name:
|
||||
try:
|
||||
return sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
return {}
|
||||
return {}
|
||||
|
||||
|
||||
def merge_cls_and_parent_ns(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
ns = get_module_ns_of(cls).copy()
|
||||
if parent_namespace is not None:
|
||||
ns.update(parent_namespace)
|
||||
ns[cls.__name__] = cls
|
||||
return ns
|
||||
|
||||
|
||||
def get_cls_type_hints_lenient(
|
||||
obj: Any, globalns: dict[str, Any] | None = None, mro: Iterable[type[Any]] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Collect annotations from a class, including those from parent classes.
|
||||
|
||||
Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable.
|
||||
"""
|
||||
hints = {}
|
||||
if mro is None:
|
||||
mro = reversed(obj.__mro__)
|
||||
for base in mro:
|
||||
ann = base.__dict__.get('__annotations__')
|
||||
localns = dict(vars(base))
|
||||
if ann is not None and ann is not GetSetDescriptorType:
|
||||
for name, value in ann.items():
|
||||
hints[name] = eval_type_lenient(value, globalns, localns)
|
||||
return hints
|
||||
|
||||
|
||||
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None) -> Any:
|
||||
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
|
||||
if value is None:
|
||||
value = NoneType
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
|
||||
try:
|
||||
return eval_type_backport(value, globalns, localns)
|
||||
except NameError:
|
||||
# the point of this function is to be tolerant to this case
|
||||
return value
|
||||
|
||||
|
||||
def eval_type_backport(
|
||||
value: Any,
|
||||
globalns: dict[str, Any] | None = None,
|
||||
localns: dict[str, Any] | None = None,
|
||||
type_params: tuple[Any] | None = None,
|
||||
) -> Any:
|
||||
"""An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport`
|
||||
package if it's installed to let older Python versions use newer typing constructs.
|
||||
|
||||
Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]`
|
||||
(as well as all the types made generic in PEP 585) if the original syntax is not supported in the
|
||||
current Python version.
|
||||
|
||||
This function will also display a helpful error if the value passed fails to evaluate.
|
||||
"""
|
||||
try:
|
||||
return _eval_type_backport(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if 'Unable to evaluate type annotation' in str(e):
|
||||
raise
|
||||
|
||||
# If it is a `TypeError` and value isn't a `ForwardRef`, it would have failed during annotation definition.
|
||||
# Thus we assert here for type checking purposes:
|
||||
assert isinstance(value, typing.ForwardRef)
|
||||
|
||||
message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.'
|
||||
if sys.version_info >= (3, 11):
|
||||
e.add_note(message)
|
||||
raise
|
||||
else:
|
||||
raise TypeError(message) from e
|
||||
|
||||
|
||||
def _eval_type_backport(
|
||||
value: Any,
|
||||
globalns: dict[str, Any] | None = None,
|
||||
localns: dict[str, Any] | None = None,
|
||||
type_params: tuple[Any] | None = None,
|
||||
) -> Any:
|
||||
try:
|
||||
return _eval_type(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
|
||||
raise
|
||||
|
||||
try:
|
||||
from eval_type_backport import eval_type_backport
|
||||
except ImportError:
|
||||
raise TypeError(
|
||||
f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use '
|
||||
'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting '
|
||||
'since Python 3.9), you should either replace the use of new syntax with the existing '
|
||||
'`typing` constructs or install the `eval_type_backport` package.'
|
||||
) from e
|
||||
|
||||
return eval_type_backport(value, globalns, localns, try_default=False)
|
||||
|
||||
|
||||
def _eval_type(
|
||||
value: Any,
|
||||
globalns: dict[str, Any] | None = None,
|
||||
localns: dict[str, Any] | None = None,
|
||||
type_params: tuple[Any] | None = None,
|
||||
) -> Any:
|
||||
if sys.version_info >= (3, 13):
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns, type_params=type_params
|
||||
)
|
||||
else:
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns
|
||||
)
|
||||
|
||||
|
||||
def is_backport_fixable_error(e: TypeError) -> bool:
|
||||
msg = str(e)
|
||||
|
||||
return (
|
||||
sys.version_info < (3, 10)
|
||||
and msg.startswith('unsupported operand type(s) for |: ')
|
||||
or sys.version_info < (3, 9)
|
||||
and "' object is not subscriptable" in msg
|
||||
)
|
||||
|
||||
|
||||
def get_function_type_hints(
|
||||
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also
|
||||
copes with `partial`.
|
||||
"""
|
||||
try:
|
||||
if isinstance(function, partial):
|
||||
annotations = function.func.__annotations__
|
||||
else:
|
||||
annotations = function.__annotations__
|
||||
except AttributeError:
|
||||
type_hints = get_type_hints(function)
|
||||
if isinstance(function, type):
|
||||
# `type[...]` is a callable, which returns an instance of itself.
|
||||
# At some point, we might even look into the return type of `__new__`
|
||||
# if it returns something else.
|
||||
type_hints.setdefault('return', function)
|
||||
return type_hints
|
||||
|
||||
globalns = get_module_ns_of(function)
|
||||
type_hints = {}
|
||||
type_params: tuple[Any] = getattr(function, '__type_params__', ()) # type: ignore
|
||||
for name, value in annotations.items():
|
||||
if include_keys is not None and name not in include_keys:
|
||||
continue
|
||||
if value is None:
|
||||
value = NoneType
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value)
|
||||
|
||||
type_hints[name] = eval_type_backport(value, globalns, types_namespace, type_params)
|
||||
|
||||
return type_hints
|
||||
|
||||
|
||||
if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1):
|
||||
|
||||
def _make_forward_ref(
|
||||
arg: Any,
|
||||
is_argument: bool = True,
|
||||
*,
|
||||
is_class: bool = False,
|
||||
) -> typing.ForwardRef:
|
||||
"""Wrapper for ForwardRef that accounts for the `is_class` argument missing in older versions.
|
||||
The `module` argument is omitted as it breaks <3.9.8, =3.10.0 and isn't used in the calls below.
|
||||
|
||||
See https://github.com/python/cpython/pull/28560 for some background.
|
||||
The backport happened on 3.9.8, see:
|
||||
https://github.com/pydantic/pydantic/discussions/6244#discussioncomment-6275458,
|
||||
and on 3.10.1 for the 3.10 branch, see:
|
||||
https://github.com/pydantic/pydantic/issues/6912
|
||||
|
||||
Implemented as EAFP with memory.
|
||||
"""
|
||||
return typing.ForwardRef(arg, is_argument)
|
||||
|
||||
else:
|
||||
_make_forward_ref = typing.ForwardRef
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
get_type_hints = typing.get_type_hints
|
||||
|
||||
else:
|
||||
"""
|
||||
For older versions of python, we have a custom implementation of `get_type_hints` which is a close as possible to
|
||||
the implementation in CPython 3.10.8.
|
||||
"""
|
||||
|
||||
@typing.no_type_check
|
||||
def get_type_hints( # noqa: C901
|
||||
obj: Any,
|
||||
globalns: dict[str, Any] | None = None,
|
||||
localns: dict[str, Any] | None = None,
|
||||
include_extras: bool = False,
|
||||
) -> dict[str, Any]: # pragma: no cover
|
||||
"""Taken verbatim from python 3.10.8 unchanged, except:
|
||||
* type annotations of the function definition above.
|
||||
* prefixing `typing.` where appropriate
|
||||
* Use `_make_forward_ref` instead of `typing.ForwardRef` to handle the `is_class` argument.
|
||||
|
||||
https://github.com/python/cpython/blob/aaaf5174241496afca7ce4d4584570190ff972fe/Lib/typing.py#L1773-L1875
|
||||
|
||||
DO NOT CHANGE THIS METHOD UNLESS ABSOLUTELY NECESSARY.
|
||||
======================================================
|
||||
|
||||
Return type hints for an object.
|
||||
|
||||
This is often the same as obj.__annotations__, but it handles
|
||||
forward references encoded as string literals, adds Optional[t] if a
|
||||
default value equal to None is set and recursively replaces all
|
||||
'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
|
||||
|
||||
The argument may be a module, class, method, or function. The annotations
|
||||
are returned as a dictionary. For classes, annotations include also
|
||||
inherited members.
|
||||
|
||||
TypeError is raised if the argument is not of a type that can contain
|
||||
annotations, and an empty dictionary is returned if no annotations are
|
||||
present.
|
||||
|
||||
BEWARE -- the behavior of globalns and localns is counterintuitive
|
||||
(unless you are familiar with how eval() and exec() work). The
|
||||
search order is locals first, then globals.
|
||||
|
||||
- If no dict arguments are passed, an attempt is made to use the
|
||||
globals from obj (or the respective module's globals for classes),
|
||||
and these are also used as the locals. If the object does not appear
|
||||
to have globals, an empty dictionary is used. For classes, the search
|
||||
order is globals first then locals.
|
||||
|
||||
- If one dict argument is passed, it is used for both globals and
|
||||
locals.
|
||||
|
||||
- If two dict arguments are passed, they specify globals and
|
||||
locals, respectively.
|
||||
"""
|
||||
if getattr(obj, '__no_type_check__', None):
|
||||
return {}
|
||||
# Classes require a special treatment.
|
||||
if isinstance(obj, type):
|
||||
hints = {}
|
||||
for base in reversed(obj.__mro__):
|
||||
if globalns is None:
|
||||
base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {})
|
||||
else:
|
||||
base_globals = globalns
|
||||
ann = base.__dict__.get('__annotations__', {})
|
||||
if isinstance(ann, types.GetSetDescriptorType):
|
||||
ann = {}
|
||||
base_locals = dict(vars(base)) if localns is None else localns
|
||||
if localns is None and globalns is None:
|
||||
# This is surprising, but required. Before Python 3.10,
|
||||
# get_type_hints only evaluated the globalns of
|
||||
# a class. To maintain backwards compatibility, we reverse
|
||||
# the globalns and localns order so that eval() looks into
|
||||
# *base_globals* first rather than *base_locals*.
|
||||
# This only affects ForwardRefs.
|
||||
base_globals, base_locals = base_locals, base_globals
|
||||
for name, value in ann.items():
|
||||
if value is None:
|
||||
value = type(None)
|
||||
if isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
|
||||
value = eval_type_backport(value, base_globals, base_locals)
|
||||
hints[name] = value
|
||||
if not include_extras and hasattr(typing, '_strip_annotations'):
|
||||
return {
|
||||
k: typing._strip_annotations(t) # type: ignore
|
||||
for k, t in hints.items()
|
||||
}
|
||||
else:
|
||||
return hints
|
||||
|
||||
if globalns is None:
|
||||
if isinstance(obj, types.ModuleType):
|
||||
globalns = obj.__dict__
|
||||
else:
|
||||
nsobj = obj
|
||||
# Find globalns for the unwrapped object.
|
||||
while hasattr(nsobj, '__wrapped__'):
|
||||
nsobj = nsobj.__wrapped__
|
||||
globalns = getattr(nsobj, '__globals__', {})
|
||||
if localns is None:
|
||||
localns = globalns
|
||||
elif localns is None:
|
||||
localns = globalns
|
||||
hints = getattr(obj, '__annotations__', None)
|
||||
if hints is None:
|
||||
# Return empty annotations for something that _could_ have them.
|
||||
if isinstance(obj, typing._allowed_types): # type: ignore
|
||||
return {}
|
||||
else:
|
||||
raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.')
|
||||
defaults = typing._get_defaults(obj) # type: ignore
|
||||
hints = dict(hints)
|
||||
for name, value in hints.items():
|
||||
if value is None:
|
||||
value = type(None)
|
||||
if isinstance(value, str):
|
||||
# class-level forward refs were handled above, this must be either
|
||||
# a module-level annotation or a function argument annotation
|
||||
|
||||
value = _make_forward_ref(
|
||||
value,
|
||||
is_argument=not isinstance(obj, types.ModuleType),
|
||||
is_class=False,
|
||||
)
|
||||
value = eval_type_backport(value, globalns, localns)
|
||||
if name in defaults and defaults[name] is None:
|
||||
value = typing.Optional[value]
|
||||
hints[name] = value
|
||||
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
|
||||
|
||||
|
||||
def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
|
||||
# so I created this convenience function
|
||||
return dataclasses.is_dataclass(_cls)
|
||||
|
||||
|
||||
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
|
||||
return isinstance(origin, TypeAliasType)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
def is_generic_alias(type_: type[Any]) -> bool:
|
||||
return isinstance(type_, (types.GenericAlias, typing._GenericAlias)) # type: ignore[attr-defined]
|
||||
|
||||
else:
|
||||
|
||||
def is_generic_alias(type_: type[Any]) -> bool:
|
||||
return isinstance(type_, typing._GenericAlias) # type: ignore
|
||||
|
||||
|
||||
def is_self_type(tp: Any) -> bool:
|
||||
"""Check if a given class is a Self type (from `typing` or `typing_extensions`)"""
|
||||
return isinstance(tp, typing_base) and getattr(tp, '_name', None) == 'Self'
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
def is_zoneinfo_type(tp: Any) -> bool:
|
||||
"""Check if a give class is a zone_info.ZoneInfo type"""
|
||||
return tp is ZoneInfo
|
||||
|
||||
else:
|
||||
|
||||
def is_zoneinfo_type(tp: Any) -> bool:
|
||||
return False
|
||||
363
venv/lib/python3.11/site-packages/pydantic/_internal/_utils.py
Normal file
363
venv/lib/python3.11/site-packages/pydantic/_internal/_utils.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Bucket of reusable internal utilities.
|
||||
|
||||
This should be reduced as much as possible with functions only used in one place, moved to that place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import keyword
|
||||
import typing
|
||||
import weakref
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from itertools import zip_longest
|
||||
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||||
from typing import Any, Mapping, TypeVar
|
||||
|
||||
from typing_extensions import TypeAlias, TypeGuard
|
||||
|
||||
from . import _repr, _typing_extra
|
||||
from ._import_utils import import_cached_base_model
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
|
||||
AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]'
|
||||
from ..main import BaseModel
|
||||
|
||||
|
||||
# these are types that are returned unchanged by deepcopy
|
||||
IMMUTABLE_NON_COLLECTIONS_TYPES: set[type[Any]] = {
|
||||
int,
|
||||
float,
|
||||
complex,
|
||||
str,
|
||||
bool,
|
||||
bytes,
|
||||
type,
|
||||
_typing_extra.NoneType,
|
||||
FunctionType,
|
||||
BuiltinFunctionType,
|
||||
LambdaType,
|
||||
weakref.ref,
|
||||
CodeType,
|
||||
# note: including ModuleType will differ from behaviour of deepcopy by not producing error.
|
||||
# It might be not a good idea in general, but considering that this function used only internally
|
||||
# against default values of fields, this will allow to actually have a field with module as default value
|
||||
ModuleType,
|
||||
NotImplemented.__class__,
|
||||
Ellipsis.__class__,
|
||||
}
|
||||
|
||||
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
|
||||
BUILTIN_COLLECTIONS: set[type[Any]] = {
|
||||
list,
|
||||
set,
|
||||
tuple,
|
||||
frozenset,
|
||||
dict,
|
||||
OrderedDict,
|
||||
defaultdict,
|
||||
deque,
|
||||
}
|
||||
|
||||
|
||||
def sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
|
||||
|
||||
|
||||
def lenient_isinstance(o: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] | None) -> bool: # pragma: no cover
|
||||
try:
|
||||
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
|
||||
except TypeError:
|
||||
if isinstance(cls, _typing_extra.WithArgsTypes):
|
||||
return False
|
||||
raise # pragma: no cover
|
||||
|
||||
|
||||
def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
|
||||
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
|
||||
unlike raw calls to lenient_issubclass.
|
||||
"""
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
|
||||
|
||||
|
||||
def is_valid_identifier(identifier: str) -> bool:
|
||||
"""Checks that a string is a valid identifier and not a Python keyword.
|
||||
:param identifier: The identifier to test.
|
||||
:return: True if the identifier is valid.
|
||||
"""
|
||||
return identifier.isidentifier() and not keyword.iskeyword(identifier)
|
||||
|
||||
|
||||
KeyType = TypeVar('KeyType')
|
||||
|
||||
|
||||
def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
|
||||
updated_mapping = mapping.copy()
|
||||
for updating_mapping in updating_mappings:
|
||||
for k, v in updating_mapping.items():
|
||||
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
|
||||
updated_mapping[k] = deep_update(updated_mapping[k], v)
|
||||
else:
|
||||
updated_mapping[k] = v
|
||||
return updated_mapping
|
||||
|
||||
|
||||
def update_not_none(mapping: dict[Any, Any], **update: Any) -> None:
|
||||
mapping.update({k: v for k, v in update.items() if v is not None})
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def unique_list(
|
||||
input_list: list[T] | tuple[T, ...],
|
||||
*,
|
||||
name_factory: typing.Callable[[T], str] = str,
|
||||
) -> list[T]:
|
||||
"""Make a list unique while maintaining order.
|
||||
We update the list if another one with the same name is set
|
||||
(e.g. model validator overridden in subclass).
|
||||
"""
|
||||
result: list[T] = []
|
||||
result_names: list[str] = []
|
||||
for v in input_list:
|
||||
v_name = name_factory(v)
|
||||
if v_name not in result_names:
|
||||
result_names.append(v_name)
|
||||
result.append(v)
|
||||
else:
|
||||
result[result_names.index(v_name)] = v
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ValueItems(_repr.Representation):
|
||||
"""Class for more convenient calculation of excluded or included fields on values."""
|
||||
|
||||
__slots__ = ('_items', '_type')
|
||||
|
||||
def __init__(self, value: Any, items: AbstractSetIntStr | MappingIntStrAny) -> None:
|
||||
items = self._coerce_items(items)
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
items = self._normalize_indexes(items, len(value)) # type: ignore
|
||||
|
||||
self._items: MappingIntStrAny = items # type: ignore
|
||||
|
||||
def is_excluded(self, item: Any) -> bool:
|
||||
"""Check if item is fully excluded.
|
||||
|
||||
:param item: key or index of a value
|
||||
"""
|
||||
return self.is_true(self._items.get(item))
|
||||
|
||||
def is_included(self, item: Any) -> bool:
|
||||
"""Check if value is contained in self._items.
|
||||
|
||||
:param item: key or index of value
|
||||
"""
|
||||
return item in self._items
|
||||
|
||||
def for_element(self, e: int | str) -> AbstractSetIntStr | MappingIntStrAny | None:
|
||||
""":param e: key or index of element on value
|
||||
:return: raw values for element if self._items is dict and contain needed element
|
||||
"""
|
||||
item = self._items.get(e) # type: ignore
|
||||
return item if not self.is_true(item) else None
|
||||
|
||||
def _normalize_indexes(self, items: MappingIntStrAny, v_length: int) -> dict[int | str, Any]:
|
||||
""":param items: dict or set of indexes which will be normalized
|
||||
:param v_length: length of sequence indexes of which will be
|
||||
|
||||
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
|
||||
{0: True, 2: True, 3: True}
|
||||
>>> self._normalize_indexes({'__all__': True}, 4)
|
||||
{0: True, 1: True, 2: True, 3: True}
|
||||
"""
|
||||
normalized_items: dict[int | str, Any] = {}
|
||||
all_items = None
|
||||
for i, v in items.items():
|
||||
if not (isinstance(v, typing.Mapping) or isinstance(v, typing.AbstractSet) or self.is_true(v)):
|
||||
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
|
||||
if i == '__all__':
|
||||
all_items = self._coerce_value(v)
|
||||
continue
|
||||
if not isinstance(i, int):
|
||||
raise TypeError(
|
||||
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
|
||||
'expected integer keys or keyword "__all__"'
|
||||
)
|
||||
normalized_i = v_length + i if i < 0 else i
|
||||
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
|
||||
|
||||
if not all_items:
|
||||
return normalized_items
|
||||
if self.is_true(all_items):
|
||||
for i in range(v_length):
|
||||
normalized_items.setdefault(i, ...)
|
||||
return normalized_items
|
||||
for i in range(v_length):
|
||||
normalized_item = normalized_items.setdefault(i, {})
|
||||
if not self.is_true(normalized_item):
|
||||
normalized_items[i] = self.merge(all_items, normalized_item)
|
||||
return normalized_items
|
||||
|
||||
@classmethod
|
||||
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
|
||||
"""Merge a `base` item with an `override` item.
|
||||
|
||||
Both `base` and `override` are converted to dictionaries if possible.
|
||||
Sets are converted to dictionaries with the sets entries as keys and
|
||||
Ellipsis as values.
|
||||
|
||||
Each key-value pair existing in `base` is merged with `override`,
|
||||
while the rest of the key-value pairs are updated recursively with this function.
|
||||
|
||||
Merging takes place based on the "union" of keys if `intersect` is
|
||||
set to `False` (default) and on the intersection of keys if
|
||||
`intersect` is set to `True`.
|
||||
"""
|
||||
override = cls._coerce_value(override)
|
||||
base = cls._coerce_value(base)
|
||||
if override is None:
|
||||
return base
|
||||
if cls.is_true(base) or base is None:
|
||||
return override
|
||||
if cls.is_true(override):
|
||||
return base if intersect else override
|
||||
|
||||
# intersection or union of keys while preserving ordering:
|
||||
if intersect:
|
||||
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
|
||||
else:
|
||||
merge_keys = list(base) + [k for k in override if k not in base]
|
||||
|
||||
merged: dict[int | str, Any] = {}
|
||||
for k in merge_keys:
|
||||
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
|
||||
if merged_item is not None:
|
||||
merged[k] = merged_item
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny:
|
||||
if isinstance(items, typing.Mapping):
|
||||
pass
|
||||
elif isinstance(items, typing.AbstractSet):
|
||||
items = dict.fromkeys(items, ...) # type: ignore
|
||||
else:
|
||||
class_name = getattr(items, '__class__', '???')
|
||||
raise TypeError(f'Unexpected type of exclude value {class_name}')
|
||||
return items # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _coerce_value(cls, value: Any) -> Any:
|
||||
if value is None or cls.is_true(value):
|
||||
return value
|
||||
return cls._coerce_items(value)
|
||||
|
||||
@staticmethod
|
||||
def is_true(v: Any) -> bool:
|
||||
return v is True or v is ...
|
||||
|
||||
def __repr_args__(self) -> _repr.ReprArgs:
|
||||
return [(None, self._items)]
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def ClassAttribute(name: str, value: T) -> T: ...
|
||||
|
||||
else:
|
||||
|
||||
class ClassAttribute:
|
||||
"""Hide class attribute from its instances."""
|
||||
|
||||
__slots__ = 'name', 'value'
|
||||
|
||||
def __init__(self, name: str, value: Any) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __get__(self, instance: Any, owner: type[Any]) -> None:
|
||||
if instance is None:
|
||||
return self.value
|
||||
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
|
||||
|
||||
|
||||
Obj = TypeVar('Obj')
|
||||
|
||||
|
||||
def smart_deepcopy(obj: Obj) -> Obj:
|
||||
"""Return type as is for immutable built-in types
|
||||
Use obj.copy() for built-in empty collections
|
||||
Use copy.deepcopy() for non-empty collections and unknown objects.
|
||||
"""
|
||||
obj_type = obj.__class__
|
||||
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
|
||||
return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway
|
||||
try:
|
||||
if not obj and obj_type in BUILTIN_COLLECTIONS:
|
||||
# faster way for empty collections, no need to copy its members
|
||||
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore
|
||||
except (TypeError, ValueError, RuntimeError):
|
||||
# do we really dare to catch ALL errors? Seems a bit risky
|
||||
pass
|
||||
|
||||
return deepcopy(obj) # slowest way when we actually might need a deepcopy
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
|
||||
"""Check that the items of `left` are the same objects as those in `right`.
|
||||
|
||||
>>> a, b = object(), object()
|
||||
>>> all_identical([a, b, a], [a, b, a])
|
||||
True
|
||||
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
|
||||
False
|
||||
"""
|
||||
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
|
||||
if left_item is not right_item:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SafeGetItemProxy:
|
||||
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
|
||||
|
||||
This makes is safe to use in `operator.itemgetter` when some keys may be missing
|
||||
"""
|
||||
|
||||
# Define __slots__manually for performances
|
||||
# @dataclasses.dataclass() only support slots=True in python>=3.10
|
||||
__slots__ = ('wrapped',)
|
||||
|
||||
wrapped: Mapping[str, Any]
|
||||
|
||||
def __getitem__(self, key: str, /) -> Any:
|
||||
return self.wrapped.get(key, _SENTINEL)
|
||||
|
||||
# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
|
||||
# https://github.com/python/mypy/issues/13713
|
||||
# https://github.com/python/typeshed/pull/8785
|
||||
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def __contains__(self, key: str, /) -> bool:
|
||||
return self.wrapped.__contains__(key)
|
||||
@@ -0,0 +1,99 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import inspect
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import pydantic_core
|
||||
|
||||
from ..config import ConfigDict
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from . import _generate_schema, _typing_extra
|
||||
from ._config import ConfigWrapper
|
||||
|
||||
|
||||
class ValidateCallWrapper:
|
||||
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
|
||||
|
||||
__slots__ = (
|
||||
'__pydantic_validator__',
|
||||
'__name__',
|
||||
'__qualname__',
|
||||
'__annotations__',
|
||||
'__dict__', # required for __module__
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: Callable[..., Any],
|
||||
config: ConfigDict | None,
|
||||
validate_return: bool,
|
||||
namespace: dict[str, Any] | None,
|
||||
):
|
||||
if isinstance(function, partial):
|
||||
func = function.func
|
||||
schema_type = func
|
||||
self.__name__ = f'partial({func.__name__})'
|
||||
self.__qualname__ = f'partial({func.__qualname__})'
|
||||
self.__module__ = func.__module__
|
||||
else:
|
||||
schema_type = function
|
||||
self.__name__ = function.__name__
|
||||
self.__qualname__ = function.__qualname__
|
||||
self.__module__ = function.__module__
|
||||
|
||||
global_ns = _typing_extra.get_module_ns_of(function)
|
||||
# TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||
# specifically, we shouldn't be pumping the namespace full of type_params
|
||||
# when we take namespace and type_params arguments in eval_type_backport
|
||||
type_params = getattr(schema_type, '__type_params__', ())
|
||||
namespace = {
|
||||
**{param.__name__: param for param in type_params},
|
||||
**(global_ns or {}),
|
||||
**(namespace or {}),
|
||||
}
|
||||
config_wrapper = ConfigWrapper(config)
|
||||
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
|
||||
core_config = config_wrapper.core_config(self)
|
||||
|
||||
self.__pydantic_validator__ = create_schema_validator(
|
||||
schema,
|
||||
schema_type,
|
||||
self.__module__,
|
||||
self.__qualname__,
|
||||
'validate_call',
|
||||
core_config,
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
|
||||
if validate_return:
|
||||
signature = inspect.signature(function)
|
||||
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
|
||||
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
|
||||
validator = create_schema_validator(
|
||||
schema,
|
||||
schema_type,
|
||||
self.__module__,
|
||||
self.__qualname__,
|
||||
'validate_call',
|
||||
core_config,
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
if inspect.iscoroutinefunction(function):
|
||||
|
||||
async def return_val_wrapper(aw: Awaitable[Any]) -> None:
|
||||
return validator.validate_python(await aw)
|
||||
|
||||
self.__return_pydantic_validator__ = return_val_wrapper
|
||||
else:
|
||||
self.__return_pydantic_validator__ = validator.validate_python
|
||||
else:
|
||||
self.__return_pydantic_validator__ = None
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
|
||||
if self.__return_pydantic_validator__:
|
||||
return self.__return_pydantic_validator__(res)
|
||||
return res
|
||||
@@ -0,0 +1,314 @@
|
||||
"""Validator functions for standard library types.
|
||||
|
||||
Import of this module is deferred since it contains imports of many standard library modules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import math
|
||||
import re
|
||||
import typing
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic_core import PydanticCustomError, core_schema
|
||||
from pydantic_core._pydantic_core import PydanticKnownError
|
||||
from pydantic_core.core_schema import ErrorType
|
||||
|
||||
|
||||
def sequence_validator(
|
||||
input_value: typing.Sequence[Any],
|
||||
/,
|
||||
validator: core_schema.ValidatorFunctionWrapHandler,
|
||||
) -> typing.Sequence[Any]:
|
||||
"""Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
|
||||
value_type = type(input_value)
|
||||
|
||||
# We don't accept any plain string as a sequence
|
||||
# Relevant issue: https://github.com/pydantic/pydantic/issues/5595
|
||||
if issubclass(value_type, (str, bytes)):
|
||||
raise PydanticCustomError(
|
||||
'sequence_str',
|
||||
"'{type_name}' instances are not allowed as a Sequence value",
|
||||
{'type_name': value_type.__name__},
|
||||
)
|
||||
|
||||
# TODO: refactor sequence validation to validate with either a list or a tuple
|
||||
# schema, depending on the type of the value.
|
||||
# Additionally, we should be able to remove one of either this validator or the
|
||||
# SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
|
||||
# Effectively, a refactor for sequence validation is needed.
|
||||
if value_type is tuple:
|
||||
input_value = list(input_value)
|
||||
|
||||
v_list = validator(input_value)
|
||||
|
||||
# the rest of the logic is just re-creating the original type from `v_list`
|
||||
if value_type is list:
|
||||
return v_list
|
||||
elif issubclass(value_type, range):
|
||||
# return the list as we probably can't re-create the range
|
||||
return v_list
|
||||
elif value_type is tuple:
|
||||
return tuple(v_list)
|
||||
else:
|
||||
# best guess at how to re-create the original type, more custom construction logic might be required
|
||||
return value_type(v_list) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def import_string(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return _import_string_logic(value)
|
||||
except ImportError as e:
|
||||
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
|
||||
else:
|
||||
# otherwise we just return the value and let the next validator do the rest of the work
|
||||
return value
|
||||
|
||||
|
||||
def _import_string_logic(dotted_path: str) -> Any:
|
||||
"""Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.
|
||||
(This is necessary to distinguish between a submodule and an attribute when there is a conflict.).
|
||||
|
||||
If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
|
||||
rather than a submodule will be attempted automatically.
|
||||
|
||||
So, for example, the following values of `dotted_path` result in the following returned values:
|
||||
* 'collections': <module 'collections'>
|
||||
* 'collections.abc': <module 'collections.abc'>
|
||||
* 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
|
||||
* `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)
|
||||
|
||||
An error will be raised under any of the following scenarios:
|
||||
* `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
|
||||
* the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
|
||||
* the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
|
||||
"""
|
||||
from importlib import import_module
|
||||
|
||||
components = dotted_path.strip().split(':')
|
||||
if len(components) > 2:
|
||||
raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")
|
||||
|
||||
module_path = components[0]
|
||||
if not module_path:
|
||||
raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')
|
||||
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
except ModuleNotFoundError as e:
|
||||
if '.' in module_path:
|
||||
# Check if it would be valid if the final item was separated from its module with a `:`
|
||||
maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1)
|
||||
try:
|
||||
return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
|
||||
except ImportError:
|
||||
pass
|
||||
raise ImportError(f'No module named {module_path!r}') from e
|
||||
raise e
|
||||
|
||||
if len(components) > 1:
|
||||
attribute = components[1]
|
||||
try:
|
||||
return getattr(module, attribute)
|
||||
except AttributeError as e:
|
||||
raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
|
||||
else:
|
||||
return module
|
||||
|
||||
|
||||
def pattern_either_validator(input_value: Any, /) -> typing.Pattern[Any]:
|
||||
if isinstance(input_value, typing.Pattern):
|
||||
return input_value
|
||||
elif isinstance(input_value, (str, bytes)):
|
||||
# todo strict mode
|
||||
return compile_pattern(input_value) # type: ignore
|
||||
else:
|
||||
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
|
||||
|
||||
|
||||
def pattern_str_validator(input_value: Any, /) -> typing.Pattern[str]:
|
||||
if isinstance(input_value, typing.Pattern):
|
||||
if isinstance(input_value.pattern, str):
|
||||
return input_value
|
||||
else:
|
||||
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
|
||||
elif isinstance(input_value, str):
|
||||
return compile_pattern(input_value)
|
||||
elif isinstance(input_value, bytes):
|
||||
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
|
||||
else:
|
||||
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
|
||||
|
||||
|
||||
def pattern_bytes_validator(input_value: Any, /) -> typing.Pattern[bytes]:
|
||||
if isinstance(input_value, typing.Pattern):
|
||||
if isinstance(input_value.pattern, bytes):
|
||||
return input_value
|
||||
else:
|
||||
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
|
||||
elif isinstance(input_value, bytes):
|
||||
return compile_pattern(input_value)
|
||||
elif isinstance(input_value, str):
|
||||
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
|
||||
else:
|
||||
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
|
||||
|
||||
|
||||
PatternType = typing.TypeVar('PatternType', str, bytes)
|
||||
|
||||
|
||||
def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
|
||||
try:
|
||||
return re.compile(pattern)
|
||||
except re.error:
|
||||
raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
|
||||
|
||||
|
||||
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
|
||||
if isinstance(input_value, IPv4Address):
|
||||
return input_value
|
||||
|
||||
try:
|
||||
return IPv4Address(input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
|
||||
|
||||
|
||||
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
|
||||
if isinstance(input_value, IPv6Address):
|
||||
return input_value
|
||||
|
||||
try:
|
||||
return IPv6Address(input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
|
||||
|
||||
|
||||
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
|
||||
"""Assume IPv4Network initialised with a default `strict` argument.
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
|
||||
"""
|
||||
if isinstance(input_value, IPv4Network):
|
||||
return input_value
|
||||
|
||||
try:
|
||||
return IPv4Network(input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
|
||||
|
||||
|
||||
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
|
||||
"""Assume IPv6Network initialised with a default `strict` argument.
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
|
||||
"""
|
||||
if isinstance(input_value, IPv6Network):
|
||||
return input_value
|
||||
|
||||
try:
|
||||
return IPv6Network(input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
|
||||
|
||||
|
||||
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
|
||||
if isinstance(input_value, IPv4Interface):
|
||||
return input_value
|
||||
|
||||
try:
|
||||
return IPv4Interface(input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
|
||||
|
||||
|
||||
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
|
||||
if isinstance(input_value, IPv6Interface):
|
||||
return input_value
|
||||
|
||||
try:
|
||||
return IPv6Interface(input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
|
||||
|
||||
|
||||
def forbid_inf_nan_check(x: Any) -> Any:
|
||||
if not math.isfinite(x):
|
||||
raise PydanticKnownError('finite_number')
|
||||
return x
|
||||
|
||||
|
||||
_InputType = typing.TypeVar('_InputType')
|
||||
|
||||
|
||||
def create_constraint_validator(
|
||||
constraint_id: str,
|
||||
predicate: Callable[[_InputType, Any], bool],
|
||||
error_type: ErrorType,
|
||||
context_gen: Callable[[Any, Any], dict[str, Any]] | None = None,
|
||||
) -> Callable[[_InputType, Any], _InputType]:
|
||||
"""Create a validator function for a given constraint.
|
||||
|
||||
Args:
|
||||
constraint_id: The constraint identifier, used to identify the constraint in error messages, ex 'gt'.
|
||||
predicate: The predicate function to apply to the input value, ex `lambda x, gt: x > gt`.
|
||||
error_type: The error type to raise if the predicate fails.
|
||||
context_gen: A function to generate the error context from the constraint value and the input value.
|
||||
"""
|
||||
|
||||
def validator(x: _InputType, constraint_value: Any) -> _InputType:
|
||||
try:
|
||||
if not predicate(x, constraint_value):
|
||||
raise PydanticKnownError(
|
||||
error_type, context_gen(constraint_value, x) if context_gen else {constraint_id: constraint_value}
|
||||
)
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint '{constraint_id}' to supplied value {x}")
|
||||
return x
|
||||
|
||||
return validator
|
||||
|
||||
|
||||
_CONSTRAINT_TO_VALIDATOR_LOOKUP: dict[str, Callable] = {
|
||||
'gt': create_constraint_validator('gt', lambda x, gt: x > gt, 'greater_than'),
|
||||
'ge': create_constraint_validator('ge', lambda x, ge: x >= ge, 'greater_than_equal'),
|
||||
'lt': create_constraint_validator('lt', lambda x, lt: x < lt, 'less_than'),
|
||||
'le': create_constraint_validator('le', lambda x, le: x <= le, 'less_than_equal'),
|
||||
'multiple_of': create_constraint_validator(
|
||||
'multiple_of', lambda x, multiple_of: x % multiple_of == 0, 'multiple_of'
|
||||
),
|
||||
'min_length': create_constraint_validator(
|
||||
'min_length',
|
||||
lambda x, min_length: len(x) >= min_length,
|
||||
'too_short',
|
||||
lambda c_val, x: {'field_type': 'Value', 'min_length': c_val, 'actual_length': len(x)},
|
||||
),
|
||||
'max_length': create_constraint_validator(
|
||||
'max_length',
|
||||
lambda x, max_length: len(x) <= max_length,
|
||||
'too_long',
|
||||
lambda c_val, x: {'field_type': 'Value', 'max_length': c_val, 'actual_length': len(x)},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_constraint_validator(constraint: str) -> Callable:
|
||||
"""Fetch the validator function for the given constraint."""
|
||||
try:
|
||||
return _CONSTRAINT_TO_VALIDATOR_LOOKUP[constraint]
|
||||
except KeyError:
|
||||
raise TypeError(f'Unknown constraint {constraint}')
|
||||
|
||||
|
||||
IP_VALIDATOR_LOOKUP: dict[type, Callable] = {
|
||||
IPv4Address: ip_v4_address_validator,
|
||||
IPv6Address: ip_v6_address_validator,
|
||||
IPv4Network: ip_v4_network_validator,
|
||||
IPv6Network: ip_v6_network_validator,
|
||||
IPv4Interface: ip_v4_interface_validator,
|
||||
IPv6Interface: ip_v6_interface_validator,
|
||||
}
|
||||
308
venv/lib/python3.11/site-packages/pydantic/_migration.py
Normal file
308
venv/lib/python3.11/site-packages/pydantic/_migration.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import sys
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from .version import version_short
|
||||
|
||||
MOVED_IN_V2 = {
|
||||
'pydantic.utils:version_info': 'pydantic.version:version_info',
|
||||
'pydantic.error_wrappers:ValidationError': 'pydantic:ValidationError',
|
||||
'pydantic.utils:to_camel': 'pydantic.alias_generators:to_pascal',
|
||||
'pydantic.utils:to_lower_camel': 'pydantic.alias_generators:to_camel',
|
||||
'pydantic:PyObject': 'pydantic.types:ImportString',
|
||||
'pydantic.types:PyObject': 'pydantic.types:ImportString',
|
||||
'pydantic.generics:GenericModel': 'pydantic.BaseModel',
|
||||
}
|
||||
|
||||
DEPRECATED_MOVED_IN_V2 = {
|
||||
'pydantic.tools:schema_of': 'pydantic.deprecated.tools:schema_of',
|
||||
'pydantic.tools:parse_obj_as': 'pydantic.deprecated.tools:parse_obj_as',
|
||||
'pydantic.tools:schema_json_of': 'pydantic.deprecated.tools:schema_json_of',
|
||||
'pydantic.json:pydantic_encoder': 'pydantic.deprecated.json:pydantic_encoder',
|
||||
'pydantic:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
|
||||
'pydantic.json:custom_pydantic_encoder': 'pydantic.deprecated.json:custom_pydantic_encoder',
|
||||
'pydantic.json:timedelta_isoformat': 'pydantic.deprecated.json:timedelta_isoformat',
|
||||
'pydantic.decorator:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
|
||||
'pydantic.class_validators:validator': 'pydantic.deprecated.class_validators:validator',
|
||||
'pydantic.class_validators:root_validator': 'pydantic.deprecated.class_validators:root_validator',
|
||||
'pydantic.config:BaseConfig': 'pydantic.deprecated.config:BaseConfig',
|
||||
'pydantic.config:Extra': 'pydantic.deprecated.config:Extra',
|
||||
}
|
||||
|
||||
REDIRECT_TO_V1 = {
|
||||
f'pydantic.utils:{obj}': f'pydantic.v1.utils:{obj}'
|
||||
for obj in (
|
||||
'deep_update',
|
||||
'GetterDict',
|
||||
'lenient_issubclass',
|
||||
'lenient_isinstance',
|
||||
'is_valid_field',
|
||||
'update_not_none',
|
||||
'import_string',
|
||||
'Representation',
|
||||
'ROOT_KEY',
|
||||
'smart_deepcopy',
|
||||
'sequence_like',
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
REMOVED_IN_V2 = {
|
||||
'pydantic:ConstrainedBytes',
|
||||
'pydantic:ConstrainedDate',
|
||||
'pydantic:ConstrainedDecimal',
|
||||
'pydantic:ConstrainedFloat',
|
||||
'pydantic:ConstrainedFrozenSet',
|
||||
'pydantic:ConstrainedInt',
|
||||
'pydantic:ConstrainedList',
|
||||
'pydantic:ConstrainedSet',
|
||||
'pydantic:ConstrainedStr',
|
||||
'pydantic:JsonWrapper',
|
||||
'pydantic:NoneBytes',
|
||||
'pydantic:NoneStr',
|
||||
'pydantic:NoneStrBytes',
|
||||
'pydantic:Protocol',
|
||||
'pydantic:Required',
|
||||
'pydantic:StrBytes',
|
||||
'pydantic:compiled',
|
||||
'pydantic.config:get_config',
|
||||
'pydantic.config:inherit_config',
|
||||
'pydantic.config:prepare_config',
|
||||
'pydantic:create_model_from_namedtuple',
|
||||
'pydantic:create_model_from_typeddict',
|
||||
'pydantic.dataclasses:create_pydantic_model_from_dataclass',
|
||||
'pydantic.dataclasses:make_dataclass_validator',
|
||||
'pydantic.dataclasses:set_validation',
|
||||
'pydantic.datetime_parse:parse_date',
|
||||
'pydantic.datetime_parse:parse_time',
|
||||
'pydantic.datetime_parse:parse_datetime',
|
||||
'pydantic.datetime_parse:parse_duration',
|
||||
'pydantic.error_wrappers:ErrorWrapper',
|
||||
'pydantic.errors:AnyStrMaxLengthError',
|
||||
'pydantic.errors:AnyStrMinLengthError',
|
||||
'pydantic.errors:ArbitraryTypeError',
|
||||
'pydantic.errors:BoolError',
|
||||
'pydantic.errors:BytesError',
|
||||
'pydantic.errors:CallableError',
|
||||
'pydantic.errors:ClassError',
|
||||
'pydantic.errors:ColorError',
|
||||
'pydantic.errors:ConfigError',
|
||||
'pydantic.errors:DataclassTypeError',
|
||||
'pydantic.errors:DateError',
|
||||
'pydantic.errors:DateNotInTheFutureError',
|
||||
'pydantic.errors:DateNotInThePastError',
|
||||
'pydantic.errors:DateTimeError',
|
||||
'pydantic.errors:DecimalError',
|
||||
'pydantic.errors:DecimalIsNotFiniteError',
|
||||
'pydantic.errors:DecimalMaxDigitsError',
|
||||
'pydantic.errors:DecimalMaxPlacesError',
|
||||
'pydantic.errors:DecimalWholeDigitsError',
|
||||
'pydantic.errors:DictError',
|
||||
'pydantic.errors:DurationError',
|
||||
'pydantic.errors:EmailError',
|
||||
'pydantic.errors:EnumError',
|
||||
'pydantic.errors:EnumMemberError',
|
||||
'pydantic.errors:ExtraError',
|
||||
'pydantic.errors:FloatError',
|
||||
'pydantic.errors:FrozenSetError',
|
||||
'pydantic.errors:FrozenSetMaxLengthError',
|
||||
'pydantic.errors:FrozenSetMinLengthError',
|
||||
'pydantic.errors:HashableError',
|
||||
'pydantic.errors:IPv4AddressError',
|
||||
'pydantic.errors:IPv4InterfaceError',
|
||||
'pydantic.errors:IPv4NetworkError',
|
||||
'pydantic.errors:IPv6AddressError',
|
||||
'pydantic.errors:IPv6InterfaceError',
|
||||
'pydantic.errors:IPv6NetworkError',
|
||||
'pydantic.errors:IPvAnyAddressError',
|
||||
'pydantic.errors:IPvAnyInterfaceError',
|
||||
'pydantic.errors:IPvAnyNetworkError',
|
||||
'pydantic.errors:IntEnumError',
|
||||
'pydantic.errors:IntegerError',
|
||||
'pydantic.errors:InvalidByteSize',
|
||||
'pydantic.errors:InvalidByteSizeUnit',
|
||||
'pydantic.errors:InvalidDiscriminator',
|
||||
'pydantic.errors:InvalidLengthForBrand',
|
||||
'pydantic.errors:JsonError',
|
||||
'pydantic.errors:JsonTypeError',
|
||||
'pydantic.errors:ListError',
|
||||
'pydantic.errors:ListMaxLengthError',
|
||||
'pydantic.errors:ListMinLengthError',
|
||||
'pydantic.errors:ListUniqueItemsError',
|
||||
'pydantic.errors:LuhnValidationError',
|
||||
'pydantic.errors:MissingDiscriminator',
|
||||
'pydantic.errors:MissingError',
|
||||
'pydantic.errors:NoneIsAllowedError',
|
||||
'pydantic.errors:NoneIsNotAllowedError',
|
||||
'pydantic.errors:NotDigitError',
|
||||
'pydantic.errors:NotNoneError',
|
||||
'pydantic.errors:NumberNotGeError',
|
||||
'pydantic.errors:NumberNotGtError',
|
||||
'pydantic.errors:NumberNotLeError',
|
||||
'pydantic.errors:NumberNotLtError',
|
||||
'pydantic.errors:NumberNotMultipleError',
|
||||
'pydantic.errors:PathError',
|
||||
'pydantic.errors:PathNotADirectoryError',
|
||||
'pydantic.errors:PathNotAFileError',
|
||||
'pydantic.errors:PathNotExistsError',
|
||||
'pydantic.errors:PatternError',
|
||||
'pydantic.errors:PyObjectError',
|
||||
'pydantic.errors:PydanticTypeError',
|
||||
'pydantic.errors:PydanticValueError',
|
||||
'pydantic.errors:SequenceError',
|
||||
'pydantic.errors:SetError',
|
||||
'pydantic.errors:SetMaxLengthError',
|
||||
'pydantic.errors:SetMinLengthError',
|
||||
'pydantic.errors:StrError',
|
||||
'pydantic.errors:StrRegexError',
|
||||
'pydantic.errors:StrictBoolError',
|
||||
'pydantic.errors:SubclassError',
|
||||
'pydantic.errors:TimeError',
|
||||
'pydantic.errors:TupleError',
|
||||
'pydantic.errors:TupleLengthError',
|
||||
'pydantic.errors:UUIDError',
|
||||
'pydantic.errors:UUIDVersionError',
|
||||
'pydantic.errors:UrlError',
|
||||
'pydantic.errors:UrlExtraError',
|
||||
'pydantic.errors:UrlHostError',
|
||||
'pydantic.errors:UrlHostTldError',
|
||||
'pydantic.errors:UrlPortError',
|
||||
'pydantic.errors:UrlSchemeError',
|
||||
'pydantic.errors:UrlSchemePermittedError',
|
||||
'pydantic.errors:UrlUserInfoError',
|
||||
'pydantic.errors:WrongConstantError',
|
||||
'pydantic.main:validate_model',
|
||||
'pydantic.networks:stricturl',
|
||||
'pydantic:parse_file_as',
|
||||
'pydantic:parse_raw_as',
|
||||
'pydantic:stricturl',
|
||||
'pydantic.tools:parse_file_as',
|
||||
'pydantic.tools:parse_raw_as',
|
||||
'pydantic.types:ConstrainedBytes',
|
||||
'pydantic.types:ConstrainedDate',
|
||||
'pydantic.types:ConstrainedDecimal',
|
||||
'pydantic.types:ConstrainedFloat',
|
||||
'pydantic.types:ConstrainedFrozenSet',
|
||||
'pydantic.types:ConstrainedInt',
|
||||
'pydantic.types:ConstrainedList',
|
||||
'pydantic.types:ConstrainedSet',
|
||||
'pydantic.types:ConstrainedStr',
|
||||
'pydantic.types:JsonWrapper',
|
||||
'pydantic.types:NoneBytes',
|
||||
'pydantic.types:NoneStr',
|
||||
'pydantic.types:NoneStrBytes',
|
||||
'pydantic.types:StrBytes',
|
||||
'pydantic.typing:evaluate_forwardref',
|
||||
'pydantic.typing:AbstractSetIntStr',
|
||||
'pydantic.typing:AnyCallable',
|
||||
'pydantic.typing:AnyClassMethod',
|
||||
'pydantic.typing:CallableGenerator',
|
||||
'pydantic.typing:DictAny',
|
||||
'pydantic.typing:DictIntStrAny',
|
||||
'pydantic.typing:DictStrAny',
|
||||
'pydantic.typing:IntStr',
|
||||
'pydantic.typing:ListStr',
|
||||
'pydantic.typing:MappingIntStrAny',
|
||||
'pydantic.typing:NoArgAnyCallable',
|
||||
'pydantic.typing:NoneType',
|
||||
'pydantic.typing:ReprArgs',
|
||||
'pydantic.typing:SetStr',
|
||||
'pydantic.typing:StrPath',
|
||||
'pydantic.typing:TupleGenerator',
|
||||
'pydantic.typing:WithArgsTypes',
|
||||
'pydantic.typing:all_literal_values',
|
||||
'pydantic.typing:display_as_type',
|
||||
'pydantic.typing:get_all_type_hints',
|
||||
'pydantic.typing:get_args',
|
||||
'pydantic.typing:get_origin',
|
||||
'pydantic.typing:get_sub_types',
|
||||
'pydantic.typing:is_callable_type',
|
||||
'pydantic.typing:is_classvar',
|
||||
'pydantic.typing:is_finalvar',
|
||||
'pydantic.typing:is_literal_type',
|
||||
'pydantic.typing:is_namedtuple',
|
||||
'pydantic.typing:is_new_type',
|
||||
'pydantic.typing:is_none_type',
|
||||
'pydantic.typing:is_typeddict',
|
||||
'pydantic.typing:is_typeddict_special',
|
||||
'pydantic.typing:is_union',
|
||||
'pydantic.typing:new_type_supertype',
|
||||
'pydantic.typing:resolve_annotations',
|
||||
'pydantic.typing:typing_base',
|
||||
'pydantic.typing:update_field_forward_refs',
|
||||
'pydantic.typing:update_model_forward_refs',
|
||||
'pydantic.utils:ClassAttribute',
|
||||
'pydantic.utils:DUNDER_ATTRIBUTES',
|
||||
'pydantic.utils:PyObjectStr',
|
||||
'pydantic.utils:ValueItems',
|
||||
'pydantic.utils:almost_equal_floats',
|
||||
'pydantic.utils:get_discriminator_alias_and_values',
|
||||
'pydantic.utils:get_model',
|
||||
'pydantic.utils:get_unique_discriminator_alias',
|
||||
'pydantic.utils:in_ipython',
|
||||
'pydantic.utils:is_valid_identifier',
|
||||
'pydantic.utils:path_type',
|
||||
'pydantic.utils:validate_field_name',
|
||||
'pydantic:validate_model',
|
||||
}
|
||||
|
||||
|
||||
def getattr_migration(module: str) -> Callable[[str], Any]:
|
||||
"""Implement PEP 562 for objects that were either moved or removed on the migration
|
||||
to V2.
|
||||
|
||||
Args:
|
||||
module: The module name.
|
||||
|
||||
Returns:
|
||||
A callable that will raise an error if the object is not found.
|
||||
"""
|
||||
# This avoids circular import with errors.py.
|
||||
from .errors import PydanticImportError
|
||||
|
||||
def wrapper(name: str) -> object:
|
||||
"""Raise an error if the object is not found, or warn if it was moved.
|
||||
|
||||
In case it was moved, it still returns the object.
|
||||
|
||||
Args:
|
||||
name: The object name.
|
||||
|
||||
Returns:
|
||||
The object.
|
||||
"""
|
||||
if name == '__path__':
|
||||
raise AttributeError(f'module {module!r} has no attribute {name!r}')
|
||||
|
||||
import warnings
|
||||
|
||||
from ._internal._validators import import_string
|
||||
|
||||
import_path = f'{module}:{name}'
|
||||
if import_path in MOVED_IN_V2.keys():
|
||||
new_location = MOVED_IN_V2[import_path]
|
||||
warnings.warn(f'`{import_path}` has been moved to `{new_location}`.')
|
||||
return import_string(MOVED_IN_V2[import_path])
|
||||
if import_path in DEPRECATED_MOVED_IN_V2:
|
||||
# skip the warning here because a deprecation warning will be raised elsewhere
|
||||
return import_string(DEPRECATED_MOVED_IN_V2[import_path])
|
||||
if import_path in REDIRECT_TO_V1:
|
||||
new_location = REDIRECT_TO_V1[import_path]
|
||||
warnings.warn(
|
||||
f'`{import_path}` has been removed. We are importing from `{new_location}` instead.'
|
||||
'See the migration guide for more details: https://docs.pydantic.dev/latest/migration/'
|
||||
)
|
||||
return import_string(REDIRECT_TO_V1[import_path])
|
||||
if import_path == 'pydantic:BaseSettings':
|
||||
raise PydanticImportError(
|
||||
'`BaseSettings` has been moved to the `pydantic-settings` package. '
|
||||
f'See https://docs.pydantic.dev/{version_short()}/migration/#basesettings-has-moved-to-pydantic-settings '
|
||||
'for more details.'
|
||||
)
|
||||
if import_path in REMOVED_IN_V2:
|
||||
raise PydanticImportError(f'`{import_path}` has been removed in V2.')
|
||||
globals: Dict[str, Any] = sys.modules[module].__dict__
|
||||
if name in globals:
|
||||
return globals[name]
|
||||
raise AttributeError(f'module {module!r} has no attribute {name!r}')
|
||||
|
||||
return wrapper
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Alias generators for converting between different capitalization conventions."""
|
||||
|
||||
import re
|
||||
|
||||
__all__ = ('to_pascal', 'to_camel', 'to_snake')
|
||||
|
||||
# TODO: in V3, change the argument names to be more descriptive
|
||||
# Generally, don't only convert from snake_case, or name the functions
|
||||
# more specifically like snake_to_camel.
|
||||
|
||||
|
||||
def to_pascal(snake: str) -> str:
|
||||
"""Convert a snake_case string to PascalCase.
|
||||
|
||||
Args:
|
||||
snake: The string to convert.
|
||||
|
||||
Returns:
|
||||
The PascalCase string.
|
||||
"""
|
||||
camel = snake.title()
|
||||
return re.sub('([0-9A-Za-z])_(?=[0-9A-Z])', lambda m: m.group(1), camel)
|
||||
|
||||
|
||||
def to_camel(snake: str) -> str:
|
||||
"""Convert a snake_case string to camelCase.
|
||||
|
||||
Args:
|
||||
snake: The string to convert.
|
||||
|
||||
Returns:
|
||||
The converted camelCase string.
|
||||
"""
|
||||
# If the string is already in camelCase and does not contain a digit followed
|
||||
# by a lowercase letter, return it as it is
|
||||
if re.match('^[a-z]+[A-Za-z0-9]*$', snake) and not re.search(r'\d[a-z]', snake):
|
||||
return snake
|
||||
|
||||
camel = to_pascal(snake)
|
||||
return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel)
|
||||
|
||||
|
||||
def to_snake(camel: str) -> str:
|
||||
"""Convert a PascalCase, camelCase, or kebab-case string to snake_case.
|
||||
|
||||
Args:
|
||||
camel: The string to convert.
|
||||
|
||||
Returns:
|
||||
The converted string in snake_case.
|
||||
"""
|
||||
# Handle the sequence of uppercase letters followed by a lowercase letter
|
||||
snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
|
||||
# Insert an underscore between a lowercase letter and an uppercase letter
|
||||
snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
|
||||
# Insert an underscore between a digit and an uppercase letter
|
||||
snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
|
||||
# Insert an underscore between a lowercase letter and a digit
|
||||
snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
|
||||
# Replace hyphens with underscores to handle kebab-case
|
||||
snake = snake.replace('-', '_')
|
||||
return snake.lower()
|
||||
132
venv/lib/python3.11/site-packages/pydantic/aliases.py
Normal file
132
venv/lib/python3.11/site-packages/pydantic/aliases.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Support for alias configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from ._internal import _internal_dataclass
|
||||
|
||||
__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices')
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class AliasPath:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/alias#aliaspath-and-aliaschoices
|
||||
|
||||
A data class used by `validation_alias` as a convenience to create aliases.
|
||||
|
||||
Attributes:
|
||||
path: A list of string or integer aliases.
|
||||
"""
|
||||
|
||||
path: list[int | str]
|
||||
|
||||
def __init__(self, first_arg: str, *args: str | int) -> None:
|
||||
self.path = [first_arg] + list(args)
|
||||
|
||||
def convert_to_aliases(self) -> list[str | int]:
|
||||
"""Converts arguments to a list of string or integer aliases.
|
||||
|
||||
Returns:
|
||||
The list of aliases.
|
||||
"""
|
||||
return self.path
|
||||
|
||||
def search_dict_for_path(self, d: dict) -> Any:
|
||||
"""Searches a dictionary for the path specified by the alias.
|
||||
|
||||
Returns:
|
||||
The value at the specified path, or `PydanticUndefined` if the path is not found.
|
||||
"""
|
||||
v = d
|
||||
for k in self.path:
|
||||
if isinstance(v, str):
|
||||
# disallow indexing into a str, like for AliasPath('x', 0) and x='abc'
|
||||
return PydanticUndefined
|
||||
try:
|
||||
v = v[k]
|
||||
except (KeyError, IndexError, TypeError):
|
||||
return PydanticUndefined
|
||||
return v
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class AliasChoices:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/alias#aliaspath-and-aliaschoices
|
||||
|
||||
A data class used by `validation_alias` as a convenience to create aliases.
|
||||
|
||||
Attributes:
|
||||
choices: A list containing a string or `AliasPath`.
|
||||
"""
|
||||
|
||||
choices: list[str | AliasPath]
|
||||
|
||||
def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None:
|
||||
self.choices = [first_choice] + list(choices)
|
||||
|
||||
def convert_to_aliases(self) -> list[list[str | int]]:
|
||||
"""Converts arguments to a list of lists containing string or integer aliases.
|
||||
|
||||
Returns:
|
||||
The list of aliases.
|
||||
"""
|
||||
aliases: list[list[str | int]] = []
|
||||
for c in self.choices:
|
||||
if isinstance(c, AliasPath):
|
||||
aliases.append(c.convert_to_aliases())
|
||||
else:
|
||||
aliases.append([c])
|
||||
return aliases
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class AliasGenerator:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/alias#using-an-aliasgenerator
|
||||
|
||||
A data class used by `alias_generator` as a convenience to create various aliases.
|
||||
|
||||
Attributes:
|
||||
alias: A callable that takes a field name and returns an alias for it.
|
||||
validation_alias: A callable that takes a field name and returns a validation alias for it.
|
||||
serialization_alias: A callable that takes a field name and returns a serialization alias for it.
|
||||
"""
|
||||
|
||||
alias: Callable[[str], str] | None = None
|
||||
validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None
|
||||
serialization_alias: Callable[[str], str] | None = None
|
||||
|
||||
def _generate_alias(
|
||||
self,
|
||||
alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'],
|
||||
allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...],
|
||||
field_name: str,
|
||||
) -> str | AliasPath | AliasChoices | None:
|
||||
"""Generate an alias of the specified kind. Returns None if the alias generator is None.
|
||||
|
||||
Raises:
|
||||
TypeError: If the alias generator produces an invalid type.
|
||||
"""
|
||||
alias = None
|
||||
if alias_generator := getattr(self, alias_kind):
|
||||
alias = alias_generator(field_name)
|
||||
if alias and not isinstance(alias, allowed_types):
|
||||
raise TypeError(
|
||||
f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`'
|
||||
)
|
||||
return alias
|
||||
|
||||
def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]:
|
||||
"""Generate `alias`, `validation_alias`, and `serialization_alias` for a field.
|
||||
|
||||
Returns:
|
||||
A tuple of three aliases - validation, alias, and serialization.
|
||||
"""
|
||||
alias = self._generate_alias('alias', (str,), field_name)
|
||||
validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name)
|
||||
serialization_alias = self._generate_alias('serialization_alias', (str,), field_name)
|
||||
|
||||
return alias, validation_alias, serialization_alias # type: ignore
|
||||
121
venv/lib/python3.11/site-packages/pydantic/annotated_handlers.py
Normal file
121
venv/lib/python3.11/site-packages/pydantic/annotated_handlers.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from pydantic_core import core_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .json_schema import JsonSchemaMode, JsonSchemaValue
|
||||
|
||||
CoreSchemaOrField = Union[
|
||||
core_schema.CoreSchema,
|
||||
core_schema.ModelField,
|
||||
core_schema.DataclassField,
|
||||
core_schema.TypedDictField,
|
||||
core_schema.ComputedField,
|
||||
]
|
||||
|
||||
__all__ = 'GetJsonSchemaHandler', 'GetCoreSchemaHandler'
|
||||
|
||||
|
||||
class GetJsonSchemaHandler:
|
||||
"""Handler to call into the next JSON schema generation function.
|
||||
|
||||
Attributes:
|
||||
mode: Json schema mode, can be `validation` or `serialization`.
|
||||
"""
|
||||
|
||||
mode: JsonSchemaMode
|
||||
|
||||
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
|
||||
"""Call the inner handler and get the JsonSchemaValue it returns.
|
||||
This will call the next JSON schema modifying function up until it calls
|
||||
into `pydantic.json_schema.GenerateJsonSchema`, which will raise a
|
||||
`pydantic.errors.PydanticInvalidForJsonSchema` error if it cannot generate
|
||||
a JSON schema.
|
||||
|
||||
Args:
|
||||
core_schema: A `pydantic_core.core_schema.CoreSchema`.
|
||||
|
||||
Returns:
|
||||
JsonSchemaValue: The JSON schema generated by the inner JSON schema modify
|
||||
functions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue, /) -> JsonSchemaValue:
|
||||
"""Get the real schema for a `{"$ref": ...}` schema.
|
||||
If the schema given is not a `$ref` schema, it will be returned as is.
|
||||
This means you don't have to check before calling this function.
|
||||
|
||||
Args:
|
||||
maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema.
|
||||
|
||||
Raises:
|
||||
LookupError: If the ref is not found.
|
||||
|
||||
Returns:
|
||||
JsonSchemaValue: A JsonSchemaValue that has no `$ref`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GetCoreSchemaHandler:
|
||||
"""Handler to call into the next CoreSchema schema generation function."""
|
||||
|
||||
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
"""Call the inner handler and get the CoreSchema it returns.
|
||||
This will call the next CoreSchema modifying function up until it calls
|
||||
into Pydantic's internal schema generation machinery, which will raise a
|
||||
`pydantic.errors.PydanticSchemaGenerationError` error if it cannot generate
|
||||
a CoreSchema for the given source type.
|
||||
|
||||
Args:
|
||||
source_type: The input type.
|
||||
|
||||
Returns:
|
||||
CoreSchema: The `pydantic-core` CoreSchema generated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
"""Generate a schema unrelated to the current context.
|
||||
Use this function if e.g. you are handling schema generation for a sequence
|
||||
and want to generate a schema for its items.
|
||||
Otherwise, you may end up doing something like applying a `min_length` constraint
|
||||
that was intended for the sequence itself to its items!
|
||||
|
||||
Args:
|
||||
source_type: The input type.
|
||||
|
||||
Returns:
|
||||
CoreSchema: The `pydantic-core` CoreSchema generated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema, /) -> core_schema.CoreSchema:
|
||||
"""Get the real schema for a `definition-ref` schema.
|
||||
If the schema given is not a `definition-ref` schema, it will be returned as is.
|
||||
This means you don't have to check before calling this function.
|
||||
|
||||
Args:
|
||||
maybe_ref_schema: A `CoreSchema`, `ref`-based or not.
|
||||
|
||||
Raises:
|
||||
LookupError: If the `ref` is not found.
|
||||
|
||||
Returns:
|
||||
A concrete `CoreSchema`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def field_name(self) -> str | None:
|
||||
"""Get the name of the closest field to this validator."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_types_namespace(self) -> dict[str, Any] | None:
|
||||
"""Internal method used during type resolution for serializer annotations."""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,5 @@
|
||||
"""`class_validators` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
604
venv/lib/python3.11/site-packages/pydantic/color.py
Normal file
604
venv/lib/python3.11/site-packages/pydantic/color.py
Normal file
@@ -0,0 +1,604 @@
|
||||
"""Color definitions are used as per the CSS3
|
||||
[CSS Color Module Level 3](http://www.w3.org/TR/css3-color/#svg-color) specification.
|
||||
|
||||
A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`.
|
||||
|
||||
In these cases the _last_ color when sorted alphabetically takes preferences,
|
||||
eg. `Color((0, 255, 255)).as_named() == 'cyan'` because "cyan" comes after "aqua".
|
||||
|
||||
Warning: Deprecated
|
||||
The `Color` class is deprecated, use `pydantic_extra_types` instead.
|
||||
See [`pydantic-extra-types.Color`](../usage/types/extra_types/color_types.md)
|
||||
for more information.
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
from colorsys import hls_to_rgb, rgb_to_hls
|
||||
from typing import Any, Callable, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, core_schema
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from ._internal import _repr
|
||||
from ._internal._schema_generation_shared import GetJsonSchemaHandler as _GetJsonSchemaHandler
|
||||
from .json_schema import JsonSchemaValue
|
||||
from .warnings import PydanticDeprecatedSince20
|
||||
|
||||
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
|
||||
ColorType = Union[ColorTuple, str]
|
||||
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
|
||||
|
||||
|
||||
class RGBA:
|
||||
"""Internal use only as a representation of a color."""
|
||||
|
||||
__slots__ = 'r', 'g', 'b', 'alpha', '_tuple'
|
||||
|
||||
def __init__(self, r: float, g: float, b: float, alpha: Optional[float]):
|
||||
self.r = r
|
||||
self.g = g
|
||||
self.b = b
|
||||
self.alpha = alpha
|
||||
|
||||
self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
|
||||
|
||||
def __getitem__(self, item: Any) -> Any:
|
||||
return self._tuple[item]
|
||||
|
||||
|
||||
# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached
|
||||
_r_255 = r'(\d{1,3}(?:\.\d+)?)'
|
||||
_r_comma = r'\s*,\s*'
|
||||
_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
|
||||
_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
|
||||
_r_sl = r'(\d{1,3}(?:\.\d+)?)%'
|
||||
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
|
||||
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
|
||||
# CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%)
|
||||
r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
|
||||
# CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%)
|
||||
r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
|
||||
# CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%)
|
||||
r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
|
||||
# CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%)
|
||||
r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
|
||||
|
||||
# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
|
||||
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
|
||||
rads = 2 * math.pi
|
||||
|
||||
|
||||
@deprecated(
|
||||
'The `Color` class is deprecated, use `pydantic_extra_types` instead. '
|
||||
'See https://docs.pydantic.dev/latest/api/pydantic_extra_types_color/.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
)
|
||||
class Color(_repr.Representation):
|
||||
"""Represents a color."""
|
||||
|
||||
__slots__ = '_original', '_rgba'
|
||||
|
||||
def __init__(self, value: ColorType) -> None:
|
||||
self._rgba: RGBA
|
||||
self._original: ColorType
|
||||
if isinstance(value, (tuple, list)):
|
||||
self._rgba = parse_tuple(value)
|
||||
elif isinstance(value, str):
|
||||
self._rgba = parse_str(value)
|
||||
elif isinstance(value, Color):
|
||||
self._rgba = value._rgba
|
||||
value = value._original
|
||||
else:
|
||||
raise PydanticCustomError(
|
||||
'color_error', 'value is not a valid color: value must be a tuple, list or string'
|
||||
)
|
||||
|
||||
# if we've got here value must be a valid color
|
||||
self._original = value
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: _GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
field_schema = {}
|
||||
field_schema.update(type='string', format='color')
|
||||
return field_schema
|
||||
|
||||
def original(self) -> ColorType:
|
||||
"""Original value passed to `Color`."""
|
||||
return self._original
|
||||
|
||||
def as_named(self, *, fallback: bool = False) -> str:
|
||||
"""Returns the name of the color if it can be found in `COLORS_BY_VALUE` dictionary,
|
||||
otherwise returns the hexadecimal representation of the color or raises `ValueError`.
|
||||
|
||||
Args:
|
||||
fallback: If True, falls back to returning the hexadecimal representation of
|
||||
the color instead of raising a ValueError when no named color is found.
|
||||
|
||||
Returns:
|
||||
The name of the color, or the hexadecimal representation of the color.
|
||||
|
||||
Raises:
|
||||
ValueError: When no named color is found and fallback is `False`.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
|
||||
try:
|
||||
return COLORS_BY_VALUE[rgb]
|
||||
except KeyError as e:
|
||||
if fallback:
|
||||
return self.as_hex()
|
||||
else:
|
||||
raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e
|
||||
else:
|
||||
return self.as_hex()
|
||||
|
||||
def as_hex(self) -> str:
|
||||
"""Returns the hexadecimal representation of the color.
|
||||
|
||||
Hex string representing the color can be 3, 4, 6, or 8 characters depending on whether the string
|
||||
a "short" representation of the color is possible and whether there's an alpha channel.
|
||||
|
||||
Returns:
|
||||
The hexadecimal representation of the color.
|
||||
"""
|
||||
values = [float_to_255(c) for c in self._rgba[:3]]
|
||||
if self._rgba.alpha is not None:
|
||||
values.append(float_to_255(self._rgba.alpha))
|
||||
|
||||
as_hex = ''.join(f'{v:02x}' for v in values)
|
||||
if all(c in repeat_colors for c in values):
|
||||
as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2))
|
||||
return '#' + as_hex
|
||||
|
||||
def as_rgb(self) -> str:
|
||||
"""Color as an `rgb(<r>, <g>, <b>)` or `rgba(<r>, <g>, <b>, <a>)` string."""
|
||||
if self._rgba.alpha is None:
|
||||
return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
|
||||
else:
|
||||
return (
|
||||
f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, '
|
||||
f'{round(self._alpha_float(), 2)})'
|
||||
)
|
||||
|
||||
def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
|
||||
"""Returns the color as an RGB or RGBA tuple.
|
||||
|
||||
Args:
|
||||
alpha: Whether to include the alpha channel. There are three options for this input:
|
||||
|
||||
- `None` (default): Include alpha only if it's set. (e.g. not `None`)
|
||||
- `True`: Always include alpha.
|
||||
- `False`: Always omit alpha.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the values of the red, green, and blue channels in the range 0 to 255.
|
||||
If alpha is included, it is in the range 0 to 1.
|
||||
"""
|
||||
r, g, b = (float_to_255(c) for c in self._rgba[:3])
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return r, g, b
|
||||
else:
|
||||
return r, g, b, self._alpha_float()
|
||||
elif alpha:
|
||||
return r, g, b, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return r, g, b
|
||||
|
||||
def as_hsl(self) -> str:
|
||||
"""Color as an `hsl(<h>, <s>, <l>)` or `hsl(<h>, <s>, <l>, <a>)` string."""
|
||||
if self._rgba.alpha is None:
|
||||
h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
|
||||
else:
|
||||
h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'
|
||||
|
||||
def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
|
||||
"""Returns the color as an HSL or HSLA tuple.
|
||||
|
||||
Args:
|
||||
alpha: Whether to include the alpha channel.
|
||||
|
||||
- `None` (default): Include the alpha channel only if it's set (e.g. not `None`).
|
||||
- `True`: Always include alpha.
|
||||
- `False`: Always omit alpha.
|
||||
|
||||
Returns:
|
||||
The color as a tuple of hue, saturation, lightness, and alpha (if included).
|
||||
All elements are in the range 0 to 1.
|
||||
|
||||
Note:
|
||||
This is HSL as used in HTML and most other places, not HLS as used in Python's `colorsys`.
|
||||
"""
|
||||
h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b) # noqa: E741
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return h, s, l
|
||||
else:
|
||||
return h, s, l, self._alpha_float()
|
||||
if alpha:
|
||||
return h, s, l, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return h, s, l
|
||||
|
||||
def _alpha_float(self) -> float:
|
||||
return 1 if self._rgba.alpha is None else self._rgba.alpha
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
|
||||
) -> core_schema.CoreSchema:
|
||||
return core_schema.with_info_plain_validator_function(
|
||||
cls._validate, serialization=core_schema.to_string_ser_schema()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, __input_value: Any, _: Any) -> 'Color':
|
||||
return cls(__input_value)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.as_named(fallback=True)
|
||||
|
||||
def __repr_args__(self) -> '_repr.ReprArgs':
|
||||
return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())]
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.as_rgb_tuple())
|
||||
|
||||
|
||||
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
|
||||
"""Parse a tuple or list to get RGBA values.
|
||||
|
||||
Args:
|
||||
value: A tuple or list.
|
||||
|
||||
Returns:
|
||||
An `RGBA` tuple parsed from the input tuple.
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: If tuple is not valid.
|
||||
"""
|
||||
if len(value) == 3:
|
||||
r, g, b = (parse_color_value(v) for v in value)
|
||||
return RGBA(r, g, b, None)
|
||||
elif len(value) == 4:
|
||||
r, g, b = (parse_color_value(v) for v in value[:3])
|
||||
return RGBA(r, g, b, parse_float_alpha(value[3]))
|
||||
else:
|
||||
raise PydanticCustomError('color_error', 'value is not a valid color: tuples must have length 3 or 4')
|
||||
|
||||
|
||||
def parse_str(value: str) -> RGBA:
|
||||
"""Parse a string representing a color to an RGBA tuple.
|
||||
|
||||
Possible formats for the input string include:
|
||||
|
||||
* named color, see `COLORS_BY_NAME`
|
||||
* hex short eg. `<prefix>fff` (prefix can be `#`, `0x` or nothing)
|
||||
* hex long eg. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing)
|
||||
* `rgb(<r>, <g>, <b>)`
|
||||
* `rgba(<r>, <g>, <b>, <a>)`
|
||||
|
||||
Args:
|
||||
value: A string representing a color.
|
||||
|
||||
Returns:
|
||||
An `RGBA` tuple parsed from the input string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input string cannot be parsed to an RGBA tuple.
|
||||
"""
|
||||
value_lower = value.lower()
|
||||
try:
|
||||
r, g, b = COLORS_BY_NAME[value_lower]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return ints_to_rgba(r, g, b, None)
|
||||
|
||||
m = re.fullmatch(r_hex_short, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v * 2, 16) for v in rgb)
|
||||
if a:
|
||||
alpha: Optional[float] = int(a * 2, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_hex_long, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v, 16) for v in rgb)
|
||||
if a:
|
||||
alpha = int(a, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_rgb, value_lower) or re.fullmatch(r_rgb_v4_style, value_lower)
|
||||
if m:
|
||||
return ints_to_rgba(*m.groups()) # type: ignore
|
||||
|
||||
m = re.fullmatch(r_hsl, value_lower) or re.fullmatch(r_hsl_v4_style, value_lower)
|
||||
if m:
|
||||
return parse_hsl(*m.groups()) # type: ignore
|
||||
|
||||
raise PydanticCustomError('color_error', 'value is not a valid color: string not recognised as a valid color')
|
||||
|
||||
|
||||
def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float] = None) -> RGBA:
|
||||
"""Converts integer or string values for RGB color and an optional alpha value to an `RGBA` object.
|
||||
|
||||
Args:
|
||||
r: An integer or string representing the red color value.
|
||||
g: An integer or string representing the green color value.
|
||||
b: An integer or string representing the blue color value.
|
||||
alpha: A float representing the alpha value. Defaults to None.
|
||||
|
||||
Returns:
|
||||
An instance of the `RGBA` class with the corresponding color and alpha values.
|
||||
"""
|
||||
return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))
|
||||
|
||||
|
||||
def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
|
||||
"""Parse the color value provided and return a number between 0 and 1.
|
||||
|
||||
Args:
|
||||
value: An integer or string color value.
|
||||
max_val: Maximum range value. Defaults to 255.
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: If the value is not a valid color.
|
||||
|
||||
Returns:
|
||||
A number between 0 and 1.
|
||||
"""
|
||||
try:
|
||||
color = float(value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('color_error', 'value is not a valid color: color values must be a valid number')
|
||||
if 0 <= color <= max_val:
|
||||
return color / max_val
|
||||
else:
|
||||
raise PydanticCustomError(
|
||||
'color_error',
|
||||
'value is not a valid color: color values must be in the range 0 to {max_val}',
|
||||
{'max_val': max_val},
|
||||
)
|
||||
|
||||
|
||||
def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
|
||||
"""Parse an alpha value checking it's a valid float in the range 0 to 1.
|
||||
|
||||
Args:
|
||||
value: The input value to parse.
|
||||
|
||||
Returns:
|
||||
The parsed value as a float, or `None` if the value was None or equal 1.
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: If the input value cannot be successfully parsed as a float in the expected range.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
if isinstance(value, str) and value.endswith('%'):
|
||||
alpha = float(value[:-1]) / 100
|
||||
else:
|
||||
alpha = float(value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be a valid float')
|
||||
|
||||
if math.isclose(alpha, 1):
|
||||
return None
|
||||
elif 0 <= alpha <= 1:
|
||||
return alpha
|
||||
else:
|
||||
raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be in the range 0 to 1')
|
||||
|
||||
|
||||
def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
|
||||
"""Parse raw hue, saturation, lightness, and alpha values and convert to RGBA.
|
||||
|
||||
Args:
|
||||
h: The hue value.
|
||||
h_units: The unit for hue value.
|
||||
sat: The saturation value.
|
||||
light: The lightness value.
|
||||
alpha: Alpha value.
|
||||
|
||||
Returns:
|
||||
An instance of `RGBA`.
|
||||
"""
|
||||
s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100)
|
||||
|
||||
h_value = float(h)
|
||||
if h_units in {None, 'deg'}:
|
||||
h_value = h_value % 360 / 360
|
||||
elif h_units == 'rad':
|
||||
h_value = h_value % rads / rads
|
||||
else:
|
||||
# turns
|
||||
h_value = h_value % 1
|
||||
|
||||
r, g, b = hls_to_rgb(h_value, l_value, s_value)
|
||||
return RGBA(r, g, b, parse_float_alpha(alpha))
|
||||
|
||||
|
||||
def float_to_255(c: float) -> int:
|
||||
"""Converts a float value between 0 and 1 (inclusive) to an integer between 0 and 255 (inclusive).
|
||||
|
||||
Args:
|
||||
c: The float value to be converted. Must be between 0 and 1 (inclusive).
|
||||
|
||||
Returns:
|
||||
The integer equivalent of the given float value rounded to the nearest whole number.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given float value is outside the acceptable range of 0 to 1 (inclusive).
|
||||
"""
|
||||
return int(round(c * 255))
|
||||
|
||||
|
||||
COLORS_BY_NAME = {
|
||||
'aliceblue': (240, 248, 255),
|
||||
'antiquewhite': (250, 235, 215),
|
||||
'aqua': (0, 255, 255),
|
||||
'aquamarine': (127, 255, 212),
|
||||
'azure': (240, 255, 255),
|
||||
'beige': (245, 245, 220),
|
||||
'bisque': (255, 228, 196),
|
||||
'black': (0, 0, 0),
|
||||
'blanchedalmond': (255, 235, 205),
|
||||
'blue': (0, 0, 255),
|
||||
'blueviolet': (138, 43, 226),
|
||||
'brown': (165, 42, 42),
|
||||
'burlywood': (222, 184, 135),
|
||||
'cadetblue': (95, 158, 160),
|
||||
'chartreuse': (127, 255, 0),
|
||||
'chocolate': (210, 105, 30),
|
||||
'coral': (255, 127, 80),
|
||||
'cornflowerblue': (100, 149, 237),
|
||||
'cornsilk': (255, 248, 220),
|
||||
'crimson': (220, 20, 60),
|
||||
'cyan': (0, 255, 255),
|
||||
'darkblue': (0, 0, 139),
|
||||
'darkcyan': (0, 139, 139),
|
||||
'darkgoldenrod': (184, 134, 11),
|
||||
'darkgray': (169, 169, 169),
|
||||
'darkgreen': (0, 100, 0),
|
||||
'darkgrey': (169, 169, 169),
|
||||
'darkkhaki': (189, 183, 107),
|
||||
'darkmagenta': (139, 0, 139),
|
||||
'darkolivegreen': (85, 107, 47),
|
||||
'darkorange': (255, 140, 0),
|
||||
'darkorchid': (153, 50, 204),
|
||||
'darkred': (139, 0, 0),
|
||||
'darksalmon': (233, 150, 122),
|
||||
'darkseagreen': (143, 188, 143),
|
||||
'darkslateblue': (72, 61, 139),
|
||||
'darkslategray': (47, 79, 79),
|
||||
'darkslategrey': (47, 79, 79),
|
||||
'darkturquoise': (0, 206, 209),
|
||||
'darkviolet': (148, 0, 211),
|
||||
'deeppink': (255, 20, 147),
|
||||
'deepskyblue': (0, 191, 255),
|
||||
'dimgray': (105, 105, 105),
|
||||
'dimgrey': (105, 105, 105),
|
||||
'dodgerblue': (30, 144, 255),
|
||||
'firebrick': (178, 34, 34),
|
||||
'floralwhite': (255, 250, 240),
|
||||
'forestgreen': (34, 139, 34),
|
||||
'fuchsia': (255, 0, 255),
|
||||
'gainsboro': (220, 220, 220),
|
||||
'ghostwhite': (248, 248, 255),
|
||||
'gold': (255, 215, 0),
|
||||
'goldenrod': (218, 165, 32),
|
||||
'gray': (128, 128, 128),
|
||||
'green': (0, 128, 0),
|
||||
'greenyellow': (173, 255, 47),
|
||||
'grey': (128, 128, 128),
|
||||
'honeydew': (240, 255, 240),
|
||||
'hotpink': (255, 105, 180),
|
||||
'indianred': (205, 92, 92),
|
||||
'indigo': (75, 0, 130),
|
||||
'ivory': (255, 255, 240),
|
||||
'khaki': (240, 230, 140),
|
||||
'lavender': (230, 230, 250),
|
||||
'lavenderblush': (255, 240, 245),
|
||||
'lawngreen': (124, 252, 0),
|
||||
'lemonchiffon': (255, 250, 205),
|
||||
'lightblue': (173, 216, 230),
|
||||
'lightcoral': (240, 128, 128),
|
||||
'lightcyan': (224, 255, 255),
|
||||
'lightgoldenrodyellow': (250, 250, 210),
|
||||
'lightgray': (211, 211, 211),
|
||||
'lightgreen': (144, 238, 144),
|
||||
'lightgrey': (211, 211, 211),
|
||||
'lightpink': (255, 182, 193),
|
||||
'lightsalmon': (255, 160, 122),
|
||||
'lightseagreen': (32, 178, 170),
|
||||
'lightskyblue': (135, 206, 250),
|
||||
'lightslategray': (119, 136, 153),
|
||||
'lightslategrey': (119, 136, 153),
|
||||
'lightsteelblue': (176, 196, 222),
|
||||
'lightyellow': (255, 255, 224),
|
||||
'lime': (0, 255, 0),
|
||||
'limegreen': (50, 205, 50),
|
||||
'linen': (250, 240, 230),
|
||||
'magenta': (255, 0, 255),
|
||||
'maroon': (128, 0, 0),
|
||||
'mediumaquamarine': (102, 205, 170),
|
||||
'mediumblue': (0, 0, 205),
|
||||
'mediumorchid': (186, 85, 211),
|
||||
'mediumpurple': (147, 112, 219),
|
||||
'mediumseagreen': (60, 179, 113),
|
||||
'mediumslateblue': (123, 104, 238),
|
||||
'mediumspringgreen': (0, 250, 154),
|
||||
'mediumturquoise': (72, 209, 204),
|
||||
'mediumvioletred': (199, 21, 133),
|
||||
'midnightblue': (25, 25, 112),
|
||||
'mintcream': (245, 255, 250),
|
||||
'mistyrose': (255, 228, 225),
|
||||
'moccasin': (255, 228, 181),
|
||||
'navajowhite': (255, 222, 173),
|
||||
'navy': (0, 0, 128),
|
||||
'oldlace': (253, 245, 230),
|
||||
'olive': (128, 128, 0),
|
||||
'olivedrab': (107, 142, 35),
|
||||
'orange': (255, 165, 0),
|
||||
'orangered': (255, 69, 0),
|
||||
'orchid': (218, 112, 214),
|
||||
'palegoldenrod': (238, 232, 170),
|
||||
'palegreen': (152, 251, 152),
|
||||
'paleturquoise': (175, 238, 238),
|
||||
'palevioletred': (219, 112, 147),
|
||||
'papayawhip': (255, 239, 213),
|
||||
'peachpuff': (255, 218, 185),
|
||||
'peru': (205, 133, 63),
|
||||
'pink': (255, 192, 203),
|
||||
'plum': (221, 160, 221),
|
||||
'powderblue': (176, 224, 230),
|
||||
'purple': (128, 0, 128),
|
||||
'red': (255, 0, 0),
|
||||
'rosybrown': (188, 143, 143),
|
||||
'royalblue': (65, 105, 225),
|
||||
'saddlebrown': (139, 69, 19),
|
||||
'salmon': (250, 128, 114),
|
||||
'sandybrown': (244, 164, 96),
|
||||
'seagreen': (46, 139, 87),
|
||||
'seashell': (255, 245, 238),
|
||||
'sienna': (160, 82, 45),
|
||||
'silver': (192, 192, 192),
|
||||
'skyblue': (135, 206, 235),
|
||||
'slateblue': (106, 90, 205),
|
||||
'slategray': (112, 128, 144),
|
||||
'slategrey': (112, 128, 144),
|
||||
'snow': (255, 250, 250),
|
||||
'springgreen': (0, 255, 127),
|
||||
'steelblue': (70, 130, 180),
|
||||
'tan': (210, 180, 140),
|
||||
'teal': (0, 128, 128),
|
||||
'thistle': (216, 191, 216),
|
||||
'tomato': (255, 99, 71),
|
||||
'turquoise': (64, 224, 208),
|
||||
'violet': (238, 130, 238),
|
||||
'wheat': (245, 222, 179),
|
||||
'white': (255, 255, 255),
|
||||
'whitesmoke': (245, 245, 245),
|
||||
'yellow': (255, 255, 0),
|
||||
'yellowgreen': (154, 205, 50),
|
||||
}
|
||||
|
||||
COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()}
|
||||
1062
venv/lib/python3.11/site-packages/pydantic/config.py
Normal file
1062
venv/lib/python3.11/site-packages/pydantic/config.py
Normal file
File diff suppressed because it is too large
Load Diff
344
venv/lib/python3.11/site-packages/pydantic/dataclasses.py
Normal file
344
venv/lib/python3.11/site-packages/pydantic/dataclasses.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""Provide an enhanced dataclass that performs validation."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, TypeVar, overload
|
||||
from warnings import warn
|
||||
|
||||
from typing_extensions import Literal, TypeGuard, dataclass_transform
|
||||
|
||||
from ._internal import _config, _decorators, _typing_extra
|
||||
from ._internal import _dataclasses as _pydantic_dataclasses
|
||||
from ._migration import getattr_migration
|
||||
from .config import ConfigDict
|
||||
from .errors import PydanticUserError
|
||||
from .fields import Field, FieldInfo, PrivateAttr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._internal._dataclasses import PydanticDataclass
|
||||
|
||||
__all__ = 'dataclass', 'rebuild_dataclass'
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
kw_only: bool = ...,
|
||||
slots: bool = ...,
|
||||
) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: type[_T], # type: ignore
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
kw_only: bool = ...,
|
||||
slots: bool = ...,
|
||||
) -> type[PydanticDataclass]: ...
|
||||
|
||||
else:
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: type[_T], # type: ignore
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
) -> type[PydanticDataclass]: ...
|
||||
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
def dataclass(
|
||||
_cls: type[_T] | None = None,
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
kw_only: bool = False,
|
||||
slots: bool = False,
|
||||
) -> Callable[[type[_T]], type[PydanticDataclass]] | type[PydanticDataclass]:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/dataclasses/
|
||||
|
||||
A decorator used to create a Pydantic-enhanced dataclass, similar to the standard Python `dataclass`,
|
||||
but with added validation.
|
||||
|
||||
This function should be used similarly to `dataclasses.dataclass`.
|
||||
|
||||
Args:
|
||||
_cls: The target `dataclass`.
|
||||
init: Included for signature compatibility with `dataclasses.dataclass`, and is passed through to
|
||||
`dataclasses.dataclass` when appropriate. If specified, must be set to `False`, as pydantic inserts its
|
||||
own `__init__` function.
|
||||
repr: A boolean indicating whether to include the field in the `__repr__` output.
|
||||
eq: Determines if a `__eq__` method should be generated for the class.
|
||||
order: Determines if comparison magic methods should be generated, such as `__lt__`, but not `__eq__`.
|
||||
unsafe_hash: Determines if a `__hash__` method should be included in the class, as in `dataclasses.dataclass`.
|
||||
frozen: Determines if the generated class should be a 'frozen' `dataclass`, which does not allow its
|
||||
attributes to be modified after it has been initialized. If not set, the value from the provided `config` argument will be used (and will default to `False` otherwise).
|
||||
config: The Pydantic config to use for the `dataclass`.
|
||||
validate_on_init: A deprecated parameter included for backwards compatibility; in V2, all Pydantic dataclasses
|
||||
are validated on init.
|
||||
kw_only: Determines if `__init__` method parameters must be specified by keyword only. Defaults to `False`.
|
||||
slots: Determines if the generated class should be a 'slots' `dataclass`, which does not allow the addition of
|
||||
new attributes after instantiation.
|
||||
|
||||
Returns:
|
||||
A decorator that accepts a class as its argument and returns a Pydantic `dataclass`.
|
||||
|
||||
Raises:
|
||||
AssertionError: Raised if `init` is not `False` or `validate_on_init` is `False`.
|
||||
"""
|
||||
assert init is False, 'pydantic.dataclasses.dataclass only supports init=False'
|
||||
assert validate_on_init is not False, 'validate_on_init=False is no longer supported'
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
kwargs = {'kw_only': kw_only, 'slots': slots}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
def make_pydantic_fields_compatible(cls: type[Any]) -> None:
|
||||
"""Make sure that stdlib `dataclasses` understands `Field` kwargs like `kw_only`
|
||||
To do that, we simply change
|
||||
`x: int = pydantic.Field(..., kw_only=True)`
|
||||
into
|
||||
`x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)`
|
||||
"""
|
||||
for annotation_cls in cls.__mro__:
|
||||
# In Python < 3.9, `__annotations__` might not be present if there are no fields.
|
||||
# we therefore need to use `getattr` to avoid an `AttributeError`.
|
||||
annotations = getattr(annotation_cls, '__annotations__', [])
|
||||
for field_name in annotations:
|
||||
field_value = getattr(cls, field_name, None)
|
||||
# Process only if this is an instance of `FieldInfo`.
|
||||
if not isinstance(field_value, FieldInfo):
|
||||
continue
|
||||
|
||||
# Initialize arguments for the standard `dataclasses.field`.
|
||||
field_args: dict = {'default': field_value}
|
||||
|
||||
# Handle `kw_only` for Python 3.10+
|
||||
if sys.version_info >= (3, 10) and field_value.kw_only:
|
||||
field_args['kw_only'] = True
|
||||
|
||||
# Set `repr` attribute if it's explicitly specified to be not `True`.
|
||||
if field_value.repr is not True:
|
||||
field_args['repr'] = field_value.repr
|
||||
|
||||
setattr(cls, field_name, dataclasses.field(**field_args))
|
||||
# In Python 3.8, dataclasses checks cls.__dict__['__annotations__'] for annotations,
|
||||
# so we must make sure it's initialized before we add to it.
|
||||
if cls.__dict__.get('__annotations__') is None:
|
||||
cls.__annotations__ = {}
|
||||
cls.__annotations__[field_name] = annotations[field_name]
|
||||
|
||||
def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
|
||||
"""Create a Pydantic dataclass from a regular dataclass.
|
||||
|
||||
Args:
|
||||
cls: The class to create the Pydantic dataclass from.
|
||||
|
||||
Returns:
|
||||
A Pydantic dataclass.
|
||||
"""
|
||||
from ._internal._utils import is_model_class
|
||||
|
||||
if is_model_class(cls):
|
||||
raise PydanticUserError(
|
||||
f'Cannot create a Pydantic dataclass from {cls.__name__} as it is already a Pydantic model',
|
||||
code='dataclass-on-model',
|
||||
)
|
||||
|
||||
original_cls = cls
|
||||
|
||||
# if config is not explicitly provided, try to read it from the type
|
||||
config_dict = config if config is not None else getattr(cls, '__pydantic_config__', None)
|
||||
config_wrapper = _config.ConfigWrapper(config_dict)
|
||||
decorators = _decorators.DecoratorInfos.build(cls)
|
||||
|
||||
# Keep track of the original __doc__ so that we can restore it after applying the dataclasses decorator
|
||||
# Otherwise, classes with no __doc__ will have their signature added into the JSON schema description,
|
||||
# since dataclasses.dataclass will set this as the __doc__
|
||||
original_doc = cls.__doc__
|
||||
|
||||
if _pydantic_dataclasses.is_builtin_dataclass(cls):
|
||||
# Don't preserve the docstring for vanilla dataclasses, as it may include the signature
|
||||
# This matches v1 behavior, and there was an explicit test for it
|
||||
original_doc = None
|
||||
|
||||
# We don't want to add validation to the existing std lib dataclass, so we will subclass it
|
||||
# If the class is generic, we need to make sure the subclass also inherits from Generic
|
||||
# with all the same parameters.
|
||||
bases = (cls,)
|
||||
if issubclass(cls, Generic):
|
||||
generic_base = Generic[cls.__parameters__] # type: ignore
|
||||
bases = bases + (generic_base,)
|
||||
cls = types.new_class(cls.__name__, bases)
|
||||
|
||||
make_pydantic_fields_compatible(cls)
|
||||
|
||||
# Respect frozen setting from dataclass constructor and fallback to config setting if not provided
|
||||
if frozen is not None:
|
||||
frozen_ = frozen
|
||||
if config_wrapper.frozen:
|
||||
# It's not recommended to define both, as the setting from the dataclass decorator will take priority.
|
||||
warn(
|
||||
f'`frozen` is set via both the `dataclass` decorator and `config` for dataclass {cls.__name__!r}.'
|
||||
'This is not recommended. The `frozen` specification on `dataclass` will take priority.',
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
else:
|
||||
frozen_ = config_wrapper.frozen or False
|
||||
|
||||
cls = dataclasses.dataclass( # type: ignore[call-overload]
|
||||
cls,
|
||||
# the value of init here doesn't affect anything except that it makes it easier to generate a signature
|
||||
init=True,
|
||||
repr=repr,
|
||||
eq=eq,
|
||||
order=order,
|
||||
unsafe_hash=unsafe_hash,
|
||||
frozen=frozen_,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
cls.__pydantic_decorators__ = decorators # type: ignore
|
||||
cls.__doc__ = original_doc
|
||||
cls.__module__ = original_cls.__module__
|
||||
cls.__qualname__ = original_cls.__qualname__
|
||||
cls.__pydantic_complete__ = False # `complete_dataclass` will set it to `True` if successful.
|
||||
_pydantic_dataclasses.complete_dataclass(cls, config_wrapper, raise_errors=False, types_namespace=None)
|
||||
return cls
|
||||
|
||||
return create_dataclass if _cls is None else create_dataclass(_cls)
|
||||
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
if (3, 8) <= sys.version_info < (3, 11):
|
||||
# Monkeypatch dataclasses.InitVar so that typing doesn't error if it occurs as a type when evaluating type hints
|
||||
# Starting in 3.11, typing.get_type_hints will not raise an error if the retrieved type hints are not callable.
|
||||
|
||||
def _call_initvar(*args: Any, **kwargs: Any) -> NoReturn:
|
||||
"""This function does nothing but raise an error that is as similar as possible to what you'd get
|
||||
if you were to try calling `InitVar[int]()` without this monkeypatch. The whole purpose is just
|
||||
to ensure typing._type_check does not error if the type hint evaluates to `InitVar[<parameter>]`.
|
||||
"""
|
||||
raise TypeError("'InitVar' object is not callable")
|
||||
|
||||
dataclasses.InitVar.__call__ = _call_initvar
|
||||
|
||||
|
||||
def rebuild_dataclass(
|
||||
cls: type[PydanticDataclass],
|
||||
*,
|
||||
force: bool = False,
|
||||
raise_errors: bool = True,
|
||||
_parent_namespace_depth: int = 2,
|
||||
_types_namespace: dict[str, Any] | None = None,
|
||||
) -> bool | None:
|
||||
"""Try to rebuild the pydantic-core schema for the dataclass.
|
||||
|
||||
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
|
||||
the initial attempt to build the schema, and automatic rebuilding fails.
|
||||
|
||||
This is analogous to `BaseModel.model_rebuild`.
|
||||
|
||||
Args:
|
||||
cls: The class to rebuild the pydantic-core schema for.
|
||||
force: Whether to force the rebuilding of the schema, defaults to `False`.
|
||||
raise_errors: Whether to raise errors, defaults to `True`.
|
||||
_parent_namespace_depth: The depth level of the parent namespace, defaults to 2.
|
||||
_types_namespace: The types namespace, defaults to `None`.
|
||||
|
||||
Returns:
|
||||
Returns `None` if the schema is already "complete" and rebuilding was not required.
|
||||
If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`.
|
||||
"""
|
||||
if not force and cls.__pydantic_complete__:
|
||||
return None
|
||||
else:
|
||||
if _types_namespace is not None:
|
||||
types_namespace: dict[str, Any] | None = _types_namespace.copy()
|
||||
else:
|
||||
if _parent_namespace_depth > 0:
|
||||
frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {}
|
||||
# Note: we may need to add something similar to cls.__pydantic_parent_namespace__ from BaseModel
|
||||
# here when implementing handling of recursive generics. See BaseModel.model_rebuild for reference.
|
||||
types_namespace = frame_parent_ns
|
||||
else:
|
||||
types_namespace = {}
|
||||
|
||||
types_namespace = _typing_extra.merge_cls_and_parent_ns(cls, types_namespace)
|
||||
return _pydantic_dataclasses.complete_dataclass(
|
||||
cls,
|
||||
_config.ConfigWrapper(cls.__pydantic_config__, check=False),
|
||||
raise_errors=raise_errors,
|
||||
types_namespace=types_namespace,
|
||||
)
|
||||
|
||||
|
||||
def is_pydantic_dataclass(class_: type[Any], /) -> TypeGuard[type[PydanticDataclass]]:
|
||||
"""Whether a class is a pydantic dataclass.
|
||||
|
||||
Args:
|
||||
class_: The class.
|
||||
|
||||
Returns:
|
||||
`True` if the class is a pydantic dataclass, `False` otherwise.
|
||||
"""
|
||||
try:
|
||||
return '__pydantic_validator__' in class_.__dict__ and dataclasses.is_dataclass(class_)
|
||||
except AttributeError:
|
||||
return False
|
||||
@@ -0,0 +1,5 @@
|
||||
"""The `datetime_parse` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
5
venv/lib/python3.11/site-packages/pydantic/decorator.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/decorator.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""The `decorator` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Old `@validator` and `@root_validator` function validators from V1."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from functools import partial, partialmethod
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
|
||||
from warnings import warn
|
||||
|
||||
from typing_extensions import Literal, Protocol, TypeAlias, deprecated
|
||||
|
||||
from .._internal import _decorators, _decorators_v1
|
||||
from ..errors import PydanticUserError
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
_ALLOW_REUSE_WARNING_MESSAGE = '`allow_reuse` is deprecated and will be ignored; it should no longer be necessary'
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _OnlyValueValidatorClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, __value: Any) -> Any: ...
|
||||
|
||||
class _V1ValidatorWithValuesClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any: ...
|
||||
|
||||
class _V1ValidatorWithValuesKwOnlyClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any: ...
|
||||
|
||||
class _V1ValidatorWithKwargsClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
class _V1ValidatorWithValuesAndKwargsClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
|
||||
|
||||
class _V1RootValidatorClsMethod(Protocol):
|
||||
def __call__(
|
||||
self, __cls: Any, __values: _decorators_v1.RootValidatorValues
|
||||
) -> _decorators_v1.RootValidatorValues: ...
|
||||
|
||||
V1Validator = Union[
|
||||
_OnlyValueValidatorClsMethod,
|
||||
_V1ValidatorWithValuesClsMethod,
|
||||
_V1ValidatorWithValuesKwOnlyClsMethod,
|
||||
_V1ValidatorWithKwargsClsMethod,
|
||||
_V1ValidatorWithValuesAndKwargsClsMethod,
|
||||
_decorators_v1.V1ValidatorWithValues,
|
||||
_decorators_v1.V1ValidatorWithValuesKwOnly,
|
||||
_decorators_v1.V1ValidatorWithKwargs,
|
||||
_decorators_v1.V1ValidatorWithValuesAndKwargs,
|
||||
]
|
||||
|
||||
V1RootValidator = Union[
|
||||
_V1RootValidatorClsMethod,
|
||||
_decorators_v1.V1RootValidatorFunction,
|
||||
]
|
||||
|
||||
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
|
||||
|
||||
# Allow both a V1 (assumed pre=False) or V2 (assumed mode='after') validator
|
||||
# We lie to type checkers and say we return the same thing we get
|
||||
# but in reality we return a proxy object that _mostly_ behaves like the wrapped thing
|
||||
_V1ValidatorType = TypeVar('_V1ValidatorType', V1Validator, _PartialClsOrStaticMethod)
|
||||
_V1RootValidatorFunctionType = TypeVar(
|
||||
'_V1RootValidatorFunctionType',
|
||||
_decorators_v1.V1RootValidatorFunction,
|
||||
_V1RootValidatorClsMethod,
|
||||
_PartialClsOrStaticMethod,
|
||||
)
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
|
||||
@deprecated(
|
||||
'Pydantic V1 style `@validator` validators are deprecated.'
|
||||
' You should migrate to Pydantic V2 style `@field_validator` validators,'
|
||||
' see the migration guide for more details',
|
||||
category=None,
|
||||
)
|
||||
def validator(
|
||||
__field: str,
|
||||
*fields: str,
|
||||
pre: bool = False,
|
||||
each_item: bool = False,
|
||||
always: bool = False,
|
||||
check_fields: bool | None = None,
|
||||
allow_reuse: bool = False,
|
||||
) -> Callable[[_V1ValidatorType], _V1ValidatorType]:
|
||||
"""Decorate methods on the class indicating that they should be used to validate fields.
|
||||
|
||||
Args:
|
||||
__field (str): The first field the validator should be called on; this is separate
|
||||
from `fields` to ensure an error is raised if you don't pass at least one.
|
||||
*fields (str): Additional field(s) the validator should be called on.
|
||||
pre (bool, optional): Whether this validator should be called before the standard
|
||||
validators (else after). Defaults to False.
|
||||
each_item (bool, optional): For complex objects (sets, lists etc.) whether to validate
|
||||
individual elements rather than the whole object. Defaults to False.
|
||||
always (bool, optional): Whether this method and other validators should be called even if
|
||||
the value is missing. Defaults to False.
|
||||
check_fields (bool | None, optional): Whether to check that the fields actually exist on the model.
|
||||
Defaults to None.
|
||||
allow_reuse (bool, optional): Whether to track and raise an error if another validator refers to
|
||||
the decorated function. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Callable: A decorator that can be used to decorate a
|
||||
function to be used as a validator.
|
||||
"""
|
||||
warn(
|
||||
'Pydantic V1 style `@validator` validators are deprecated.'
|
||||
' You should migrate to Pydantic V2 style `@field_validator` validators,'
|
||||
' see the migration guide for more details',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if allow_reuse is True: # pragma: no cover
|
||||
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
|
||||
fields = __field, *fields
|
||||
if isinstance(fields[0], FunctionType):
|
||||
raise PydanticUserError(
|
||||
'`@validator` should be used with fields and keyword arguments, not bare. '
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`",
|
||||
code='validator-no-fields',
|
||||
)
|
||||
elif not all(isinstance(field, str) for field in fields):
|
||||
raise PydanticUserError(
|
||||
'`@validator` fields should be passed as separate string args. '
|
||||
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
|
||||
code='validator-invalid-fields',
|
||||
)
|
||||
|
||||
mode: Literal['before', 'after'] = 'before' if pre is True else 'after'
|
||||
|
||||
def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
if _decorators.is_instance_method_from_sig(f):
|
||||
raise PydanticUserError(
|
||||
'`@validator` cannot be applied to instance methods', code='validator-instance-method'
|
||||
)
|
||||
# auto apply the @classmethod decorator
|
||||
f = _decorators.ensure_classmethod_based_on_signature(f)
|
||||
wrap = _decorators_v1.make_generic_v1_field_validator
|
||||
validator_wrapper_info = _decorators.ValidatorDecoratorInfo(
|
||||
fields=fields,
|
||||
mode=mode,
|
||||
each_item=each_item,
|
||||
always=always,
|
||||
check_fields=check_fields,
|
||||
)
|
||||
return _decorators.PydanticDescriptorProxy(f, validator_wrapper_info, shim=wrap)
|
||||
|
||||
return dec # type: ignore[return-value]
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(
|
||||
*,
|
||||
# if you don't specify `pre` the default is `pre=False`
|
||||
# which means you need to specify `skip_on_failure=True`
|
||||
skip_on_failure: Literal[True],
|
||||
allow_reuse: bool = ...,
|
||||
) -> Callable[
|
||||
[_V1RootValidatorFunctionType],
|
||||
_V1RootValidatorFunctionType,
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(
|
||||
*,
|
||||
# if you specify `pre=True` then you don't need to specify
|
||||
# `skip_on_failure`, in fact it is not allowed as an argument!
|
||||
pre: Literal[True],
|
||||
allow_reuse: bool = ...,
|
||||
) -> Callable[
|
||||
[_V1RootValidatorFunctionType],
|
||||
_V1RootValidatorFunctionType,
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(
|
||||
*,
|
||||
# if you explicitly specify `pre=False` then you
|
||||
# MUST specify `skip_on_failure=True`
|
||||
pre: Literal[False],
|
||||
skip_on_failure: Literal[True],
|
||||
allow_reuse: bool = ...,
|
||||
) -> Callable[
|
||||
[_V1RootValidatorFunctionType],
|
||||
_V1RootValidatorFunctionType,
|
||||
]: ...
|
||||
|
||||
|
||||
@deprecated(
|
||||
'Pydantic V1 style `@root_validator` validators are deprecated.'
|
||||
' You should migrate to Pydantic V2 style `@model_validator` validators,'
|
||||
' see the migration guide for more details',
|
||||
category=None,
|
||||
)
|
||||
def root_validator(
|
||||
*__args,
|
||||
pre: bool = False,
|
||||
skip_on_failure: bool = False,
|
||||
allow_reuse: bool = False,
|
||||
) -> Any:
|
||||
"""Decorate methods on a model indicating that they should be used to validate (and perhaps
|
||||
modify) data either before or after standard model parsing/validation is performed.
|
||||
|
||||
Args:
|
||||
pre (bool, optional): Whether this validator should be called before the standard
|
||||
validators (else after). Defaults to False.
|
||||
skip_on_failure (bool, optional): Whether to stop validation and return as soon as a
|
||||
failure is encountered. Defaults to False.
|
||||
allow_reuse (bool, optional): Whether to track and raise an error if another validator
|
||||
refers to the decorated function. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Any: A decorator that can be used to decorate a function to be used as a root_validator.
|
||||
"""
|
||||
warn(
|
||||
'Pydantic V1 style `@root_validator` validators are deprecated.'
|
||||
' You should migrate to Pydantic V2 style `@model_validator` validators,'
|
||||
' see the migration guide for more details',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if __args:
|
||||
# Ensure a nice error is raised if someone attempts to use the bare decorator
|
||||
return root_validator()(*__args) # type: ignore
|
||||
|
||||
if allow_reuse is True: # pragma: no cover
|
||||
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
|
||||
mode: Literal['before', 'after'] = 'before' if pre is True else 'after'
|
||||
if pre is False and skip_on_failure is not True:
|
||||
raise PydanticUserError(
|
||||
'If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`.'
|
||||
' Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.',
|
||||
code='root-validator-pre-skip',
|
||||
)
|
||||
|
||||
wrap = partial(_decorators_v1.make_v1_generic_root_validator, pre=pre)
|
||||
|
||||
def dec(f: Callable[..., Any] | classmethod[Any, Any, Any] | staticmethod[Any, Any]) -> Any:
|
||||
if _decorators.is_instance_method_from_sig(f):
|
||||
raise TypeError('`@root_validator` cannot be applied to instance methods')
|
||||
# auto apply the @classmethod decorator
|
||||
res = _decorators.ensure_classmethod_based_on_signature(f)
|
||||
dec_info = _decorators.RootValidatorDecoratorInfo(mode=mode)
|
||||
return _decorators.PydanticDescriptorProxy(res, dec_info, shim=wrap)
|
||||
|
||||
return dec
|
||||
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import Literal, deprecated
|
||||
|
||||
from .._internal import _config
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
__all__ = 'BaseConfig', 'Extra'
|
||||
|
||||
|
||||
class _ConfigMetaclass(type):
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
try:
|
||||
obj = _config.config_defaults[item]
|
||||
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
return obj
|
||||
except KeyError as exc:
|
||||
raise AttributeError(f"type object '{self.__name__}' has no attribute {exc}") from exc
|
||||
|
||||
|
||||
@deprecated('BaseConfig is deprecated. Use the `pydantic.ConfigDict` instead.', category=PydanticDeprecatedSince20)
|
||||
class BaseConfig(metaclass=_ConfigMetaclass):
|
||||
"""This class is only retained for backwards compatibility.
|
||||
|
||||
!!! Warning "Deprecated"
|
||||
BaseConfig is deprecated. Use the [`pydantic.ConfigDict`][pydantic.ConfigDict] instead.
|
||||
"""
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
try:
|
||||
obj = super().__getattribute__(item)
|
||||
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
return obj
|
||||
except AttributeError as exc:
|
||||
try:
|
||||
return getattr(type(self), item)
|
||||
except AttributeError:
|
||||
# re-raising changes the displayed text to reflect that `self` is not a type
|
||||
raise AttributeError(str(exc)) from exc
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
return super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
class _ExtraMeta(type):
|
||||
def __getattribute__(self, __name: str) -> Any:
|
||||
# The @deprecated decorator accesses other attributes, so we only emit a warning for the expected ones
|
||||
if __name in {'allow', 'ignore', 'forbid'}:
|
||||
warnings.warn(
|
||||
"`pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`)",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return super().__getattribute__(__name)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"Extra is deprecated. Use literal values instead (e.g. `extra='allow'`)", category=PydanticDeprecatedSince20
|
||||
)
|
||||
class Extra(metaclass=_ExtraMeta):
|
||||
allow: Literal['allow'] = 'allow'
|
||||
ignore: Literal['ignore'] = 'ignore'
|
||||
forbid: Literal['forbid'] = 'forbid'
|
||||
@@ -0,0 +1,224 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import typing
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Tuple
|
||||
|
||||
import typing_extensions
|
||||
|
||||
from .._internal import (
|
||||
_model_construction,
|
||||
_typing_extra,
|
||||
_utils,
|
||||
)
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .. import BaseModel
|
||||
from .._internal._utils import AbstractSetIntStr, MappingIntStrAny
|
||||
|
||||
AnyClassMethod = classmethod[Any, Any, Any]
|
||||
TupleGenerator = typing.Generator[Tuple[str, Any], None, None]
|
||||
Model = typing.TypeVar('Model', bound='BaseModel')
|
||||
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
|
||||
IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
|
||||
|
||||
_object_setattr = _model_construction.object_setattr
|
||||
|
||||
|
||||
def _iter(
|
||||
self: BaseModel,
|
||||
to_dict: bool = False,
|
||||
by_alias: bool = False,
|
||||
include: AbstractSetIntStr | MappingIntStrAny | None = None,
|
||||
exclude: AbstractSetIntStr | MappingIntStrAny | None = None,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
) -> TupleGenerator:
|
||||
# Merge field set excludes with explicit exclude parameter with explicit overriding field set options.
|
||||
# The extra "is not None" guards are not logically necessary but optimizes performance for the simple case.
|
||||
if exclude is not None:
|
||||
exclude = _utils.ValueItems.merge(
|
||||
{k: v.exclude for k, v in self.model_fields.items() if v.exclude is not None}, exclude
|
||||
)
|
||||
|
||||
if include is not None:
|
||||
include = _utils.ValueItems.merge({k: True for k in self.model_fields}, include, intersect=True)
|
||||
|
||||
allowed_keys = _calculate_keys(self, include=include, exclude=exclude, exclude_unset=exclude_unset) # type: ignore
|
||||
if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none):
|
||||
# huge boost for plain _iter()
|
||||
yield from self.__dict__.items()
|
||||
if self.__pydantic_extra__:
|
||||
yield from self.__pydantic_extra__.items()
|
||||
return
|
||||
|
||||
value_exclude = _utils.ValueItems(self, exclude) if exclude is not None else None
|
||||
value_include = _utils.ValueItems(self, include) if include is not None else None
|
||||
|
||||
if self.__pydantic_extra__ is None:
|
||||
items = self.__dict__.items()
|
||||
else:
|
||||
items = list(self.__dict__.items()) + list(self.__pydantic_extra__.items())
|
||||
|
||||
for field_key, v in items:
|
||||
if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None):
|
||||
continue
|
||||
|
||||
if exclude_defaults:
|
||||
try:
|
||||
field = self.model_fields[field_key]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if not field.is_required() and field.default == v:
|
||||
continue
|
||||
|
||||
if by_alias and field_key in self.model_fields:
|
||||
dict_key = self.model_fields[field_key].alias or field_key
|
||||
else:
|
||||
dict_key = field_key
|
||||
|
||||
if to_dict or value_include or value_exclude:
|
||||
v = _get_value(
|
||||
type(self),
|
||||
v,
|
||||
to_dict=to_dict,
|
||||
by_alias=by_alias,
|
||||
include=value_include and value_include.for_element(field_key),
|
||||
exclude=value_exclude and value_exclude.for_element(field_key),
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
yield dict_key, v
|
||||
|
||||
|
||||
def _copy_and_set_values(
|
||||
self: Model,
|
||||
values: dict[str, Any],
|
||||
fields_set: set[str],
|
||||
extra: dict[str, Any] | None = None,
|
||||
private: dict[str, Any] | None = None,
|
||||
*,
|
||||
deep: bool, # UP006
|
||||
) -> Model:
|
||||
if deep:
|
||||
# chances of having empty dict here are quite low for using smart_deepcopy
|
||||
values = deepcopy(values)
|
||||
extra = deepcopy(extra)
|
||||
private = deepcopy(private)
|
||||
|
||||
cls = self.__class__
|
||||
m = cls.__new__(cls)
|
||||
_object_setattr(m, '__dict__', values)
|
||||
_object_setattr(m, '__pydantic_extra__', extra)
|
||||
_object_setattr(m, '__pydantic_fields_set__', fields_set)
|
||||
_object_setattr(m, '__pydantic_private__', private)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
@typing.no_type_check
|
||||
def _get_value(
|
||||
cls: type[BaseModel],
|
||||
v: Any,
|
||||
to_dict: bool,
|
||||
by_alias: bool,
|
||||
include: AbstractSetIntStr | MappingIntStrAny | None,
|
||||
exclude: AbstractSetIntStr | MappingIntStrAny | None,
|
||||
exclude_unset: bool,
|
||||
exclude_defaults: bool,
|
||||
exclude_none: bool,
|
||||
) -> Any:
|
||||
from .. import BaseModel
|
||||
|
||||
if isinstance(v, BaseModel):
|
||||
if to_dict:
|
||||
return v.model_dump(
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
include=include, # type: ignore
|
||||
exclude=exclude, # type: ignore
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
else:
|
||||
return v.copy(include=include, exclude=exclude)
|
||||
|
||||
value_exclude = _utils.ValueItems(v, exclude) if exclude else None
|
||||
value_include = _utils.ValueItems(v, include) if include else None
|
||||
|
||||
if isinstance(v, dict):
|
||||
return {
|
||||
k_: _get_value(
|
||||
cls,
|
||||
v_,
|
||||
to_dict=to_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
include=value_include and value_include.for_element(k_),
|
||||
exclude=value_exclude and value_exclude.for_element(k_),
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
for k_, v_ in v.items()
|
||||
if (not value_exclude or not value_exclude.is_excluded(k_))
|
||||
and (not value_include or value_include.is_included(k_))
|
||||
}
|
||||
|
||||
elif _utils.sequence_like(v):
|
||||
seq_args = (
|
||||
_get_value(
|
||||
cls,
|
||||
v_,
|
||||
to_dict=to_dict,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
include=value_include and value_include.for_element(i),
|
||||
exclude=value_exclude and value_exclude.for_element(i),
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
for i, v_ in enumerate(v)
|
||||
if (not value_exclude or not value_exclude.is_excluded(i))
|
||||
and (not value_include or value_include.is_included(i))
|
||||
)
|
||||
|
||||
return v.__class__(*seq_args) if _typing_extra.is_namedtuple(v.__class__) else v.__class__(seq_args)
|
||||
|
||||
elif isinstance(v, Enum) and getattr(cls.model_config, 'use_enum_values', False):
|
||||
return v.value
|
||||
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
def _calculate_keys(
|
||||
self: BaseModel,
|
||||
include: MappingIntStrAny | None,
|
||||
exclude: MappingIntStrAny | None,
|
||||
exclude_unset: bool,
|
||||
update: typing.Dict[str, Any] | None = None, # noqa UP006
|
||||
) -> typing.AbstractSet[str] | None:
|
||||
if include is None and exclude is None and exclude_unset is False:
|
||||
return None
|
||||
|
||||
keys: typing.AbstractSet[str]
|
||||
if exclude_unset:
|
||||
keys = self.__pydantic_fields_set__.copy()
|
||||
else:
|
||||
keys = set(self.__dict__.keys())
|
||||
keys = keys | (self.__pydantic_extra__ or {}).keys()
|
||||
|
||||
if include is not None:
|
||||
keys &= include.keys()
|
||||
|
||||
if update:
|
||||
keys -= update.keys()
|
||||
|
||||
if exclude:
|
||||
keys -= {k for k, v in exclude.items() if _utils.ValueItems.is_true(v)}
|
||||
|
||||
return keys
|
||||
@@ -0,0 +1,279 @@
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from .._internal import _config, _typing_extra
|
||||
from ..alias_generators import to_pascal
|
||||
from ..errors import PydanticUserError
|
||||
from ..functional_validators import field_validator
|
||||
from ..main import BaseModel, create_model
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
__all__ = ('validate_arguments',)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
AnyCallable = Callable[..., Any]
|
||||
|
||||
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
|
||||
ConfigType = Union[None, Type[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(
|
||||
func: None = None, *, config: 'ConfigType' = None
|
||||
) -> Callable[['AnyCallableT'], 'AnyCallableT']: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': ...
|
||||
|
||||
|
||||
@deprecated(
|
||||
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
|
||||
category=None,
|
||||
)
|
||||
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
|
||||
"""Decorator to validate the arguments passed to a function."""
|
||||
warnings.warn(
|
||||
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
|
||||
PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def validate(_func: 'AnyCallable') -> 'AnyCallable':
|
||||
vd = ValidatedFunction(_func, config)
|
||||
|
||||
@wraps(_func)
|
||||
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
|
||||
return vd.call(*args, **kwargs)
|
||||
|
||||
wrapper_function.vd = vd # type: ignore
|
||||
wrapper_function.validate = vd.init_model_instance # type: ignore
|
||||
wrapper_function.raw_function = vd.raw_function # type: ignore
|
||||
wrapper_function.model = vd.model # type: ignore
|
||||
return wrapper_function
|
||||
|
||||
if func:
|
||||
return validate(func)
|
||||
else:
|
||||
return validate
|
||||
|
||||
|
||||
ALT_V_ARGS = 'v__args'
|
||||
ALT_V_KWARGS = 'v__kwargs'
|
||||
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
|
||||
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
|
||||
|
||||
|
||||
class ValidatedFunction:
|
||||
def __init__(self, function: 'AnyCallable', config: 'ConfigType'):
|
||||
from inspect import Parameter, signature
|
||||
|
||||
parameters: Mapping[str, Parameter] = signature(function).parameters
|
||||
|
||||
if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
|
||||
raise PydanticUserError(
|
||||
f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
|
||||
f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator',
|
||||
code=None,
|
||||
)
|
||||
|
||||
self.raw_function = function
|
||||
self.arg_mapping: Dict[int, str] = {}
|
||||
self.positional_only_args: set[str] = set()
|
||||
self.v_args_name = 'args'
|
||||
self.v_kwargs_name = 'kwargs'
|
||||
|
||||
type_hints = _typing_extra.get_type_hints(function, include_extras=True)
|
||||
takes_args = False
|
||||
takes_kwargs = False
|
||||
fields: Dict[str, Tuple[Any, Any]] = {}
|
||||
for i, (name, p) in enumerate(parameters.items()):
|
||||
if p.annotation is p.empty:
|
||||
annotation = Any
|
||||
else:
|
||||
annotation = type_hints[name]
|
||||
|
||||
default = ... if p.default is p.empty else p.default
|
||||
if p.kind == Parameter.POSITIONAL_ONLY:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
|
||||
self.positional_only_args.add(name)
|
||||
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_DUPLICATE_KWARGS] = List[str], None
|
||||
elif p.kind == Parameter.KEYWORD_ONLY:
|
||||
fields[name] = annotation, default
|
||||
elif p.kind == Parameter.VAR_POSITIONAL:
|
||||
self.v_args_name = name
|
||||
fields[name] = Tuple[annotation, ...], None
|
||||
takes_args = True
|
||||
else:
|
||||
assert p.kind == Parameter.VAR_KEYWORD, p.kind
|
||||
self.v_kwargs_name = name
|
||||
fields[name] = Dict[str, annotation], None
|
||||
takes_kwargs = True
|
||||
|
||||
# these checks avoid a clash between "args" and a field with that name
|
||||
if not takes_args and self.v_args_name in fields:
|
||||
self.v_args_name = ALT_V_ARGS
|
||||
|
||||
# same with "kwargs"
|
||||
if not takes_kwargs and self.v_kwargs_name in fields:
|
||||
self.v_kwargs_name = ALT_V_KWARGS
|
||||
|
||||
if not takes_args:
|
||||
# we add the field so validation below can raise the correct exception
|
||||
fields[self.v_args_name] = List[Any], None
|
||||
|
||||
if not takes_kwargs:
|
||||
# same with kwargs
|
||||
fields[self.v_kwargs_name] = Dict[Any, Any], None
|
||||
|
||||
self.create_model(fields, takes_args, takes_kwargs, config)
|
||||
|
||||
def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
|
||||
values = self.build_values(args, kwargs)
|
||||
return self.model(**values)
|
||||
|
||||
def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
m = self.init_model_instance(*args, **kwargs)
|
||||
return self.execute(m)
|
||||
|
||||
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values: Dict[str, Any] = {}
|
||||
if args:
|
||||
arg_iter = enumerate(args)
|
||||
while True:
|
||||
try:
|
||||
i, a = next(arg_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
arg_name = self.arg_mapping.get(i)
|
||||
if arg_name is not None:
|
||||
values[arg_name] = a
|
||||
else:
|
||||
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
|
||||
break
|
||||
|
||||
var_kwargs: Dict[str, Any] = {}
|
||||
wrong_positional_args = []
|
||||
duplicate_kwargs = []
|
||||
fields_alias = [
|
||||
field.alias
|
||||
for name, field in self.model.model_fields.items()
|
||||
if name not in (self.v_args_name, self.v_kwargs_name)
|
||||
]
|
||||
non_var_fields = set(self.model.model_fields) - {self.v_args_name, self.v_kwargs_name}
|
||||
for k, v in kwargs.items():
|
||||
if k in non_var_fields or k in fields_alias:
|
||||
if k in self.positional_only_args:
|
||||
wrong_positional_args.append(k)
|
||||
if k in values:
|
||||
duplicate_kwargs.append(k)
|
||||
values[k] = v
|
||||
else:
|
||||
var_kwargs[k] = v
|
||||
|
||||
if var_kwargs:
|
||||
values[self.v_kwargs_name] = var_kwargs
|
||||
if wrong_positional_args:
|
||||
values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
|
||||
if duplicate_kwargs:
|
||||
values[V_DUPLICATE_KWARGS] = duplicate_kwargs
|
||||
return values
|
||||
|
||||
def execute(self, m: BaseModel) -> Any:
|
||||
d = {k: v for k, v in m.__dict__.items() if k in m.__pydantic_fields_set__ or m.model_fields[k].default_factory}
|
||||
var_kwargs = d.pop(self.v_kwargs_name, {})
|
||||
|
||||
if self.v_args_name in d:
|
||||
args_: List[Any] = []
|
||||
in_kwargs = False
|
||||
kwargs = {}
|
||||
for name, value in d.items():
|
||||
if in_kwargs:
|
||||
kwargs[name] = value
|
||||
elif name == self.v_args_name:
|
||||
args_ += value
|
||||
in_kwargs = True
|
||||
else:
|
||||
args_.append(value)
|
||||
return self.raw_function(*args_, **kwargs, **var_kwargs)
|
||||
elif self.positional_only_args:
|
||||
args_ = []
|
||||
kwargs = {}
|
||||
for name, value in d.items():
|
||||
if name in self.positional_only_args:
|
||||
args_.append(value)
|
||||
else:
|
||||
kwargs[name] = value
|
||||
return self.raw_function(*args_, **kwargs, **var_kwargs)
|
||||
else:
|
||||
return self.raw_function(**d, **var_kwargs)
|
||||
|
||||
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
|
||||
pos_args = len(self.arg_mapping)
|
||||
|
||||
config_wrapper = _config.ConfigWrapper(config)
|
||||
|
||||
if config_wrapper.alias_generator:
|
||||
raise PydanticUserError(
|
||||
'Setting the "alias_generator" property on custom Config for '
|
||||
'@validate_arguments is not yet supported, please remove.',
|
||||
code=None,
|
||||
)
|
||||
if config_wrapper.extra is None:
|
||||
config_wrapper.config_dict['extra'] = 'forbid'
|
||||
|
||||
class DecoratorBaseModel(BaseModel):
|
||||
@field_validator(self.v_args_name, check_fields=False)
|
||||
@classmethod
|
||||
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
|
||||
if takes_args or v is None:
|
||||
return v
|
||||
|
||||
raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')
|
||||
|
||||
@field_validator(self.v_kwargs_name, check_fields=False)
|
||||
@classmethod
|
||||
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if takes_kwargs or v is None:
|
||||
return v
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v.keys()))
|
||||
raise TypeError(f'unexpected keyword argument{plural}: {keys}')
|
||||
|
||||
@field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
|
||||
@classmethod
|
||||
def check_positional_only(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v))
|
||||
raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
|
||||
|
||||
@field_validator(V_DUPLICATE_KWARGS, check_fields=False)
|
||||
@classmethod
|
||||
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v))
|
||||
raise TypeError(f'multiple values for argument{plural}: {keys}')
|
||||
|
||||
model_config = config_wrapper.config_dict
|
||||
|
||||
self.model = create_model(to_pascal(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
|
||||
141
venv/lib/python3.11/site-packages/pydantic/deprecated/json.py
Normal file
141
venv/lib/python3.11/site-packages/pydantic/deprecated/json.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import datetime
|
||||
import warnings
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from types import GeneratorType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from .._internal._import_utils import import_cached_base_model
|
||||
from ..color import Color
|
||||
from ..networks import NameEmail
|
||||
from ..types import SecretBytes, SecretStr
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
|
||||
|
||||
|
||||
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
|
||||
return o.isoformat()
|
||||
|
||||
|
||||
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
|
||||
"""Encodes a Decimal as int of there's no exponent, otherwise float.
|
||||
|
||||
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
|
||||
where a integer (but not int typed) is used. Encoding this as a float
|
||||
results in failed round-tripping between encode and parse.
|
||||
Our Id type is a prime example of this.
|
||||
|
||||
>>> decimal_encoder(Decimal("1.0"))
|
||||
1.0
|
||||
|
||||
>>> decimal_encoder(Decimal("1"))
|
||||
1
|
||||
"""
|
||||
exponent = dec_value.as_tuple().exponent
|
||||
if isinstance(exponent, int) and exponent >= 0:
|
||||
return int(dec_value)
|
||||
else:
|
||||
return float(dec_value)
|
||||
|
||||
|
||||
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
bytes: lambda o: o.decode(),
|
||||
Color: str,
|
||||
datetime.date: isoformat,
|
||||
datetime.datetime: isoformat,
|
||||
datetime.time: isoformat,
|
||||
datetime.timedelta: lambda td: td.total_seconds(),
|
||||
Decimal: decimal_encoder,
|
||||
Enum: lambda o: o.value,
|
||||
frozenset: list,
|
||||
deque: list,
|
||||
GeneratorType: list,
|
||||
IPv4Address: str,
|
||||
IPv4Interface: str,
|
||||
IPv4Network: str,
|
||||
IPv6Address: str,
|
||||
IPv6Interface: str,
|
||||
IPv6Network: str,
|
||||
NameEmail: str,
|
||||
Path: str,
|
||||
Pattern: lambda o: o.pattern,
|
||||
SecretBytes: str,
|
||||
SecretStr: str,
|
||||
set: list,
|
||||
UUID: str,
|
||||
}
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
|
||||
category=None,
|
||||
)
|
||||
def pydantic_encoder(obj: Any) -> Any:
|
||||
warnings.warn(
|
||||
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
from dataclasses import asdict, is_dataclass
|
||||
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.model_dump()
|
||||
elif is_dataclass(obj):
|
||||
return asdict(obj) # type: ignore
|
||||
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
try:
|
||||
encoder = ENCODERS_BY_TYPE[base]
|
||||
except KeyError:
|
||||
continue
|
||||
return encoder(obj)
|
||||
else: # We have exited the for loop without finding a suitable encoder
|
||||
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
|
||||
|
||||
|
||||
# TODO: Add a suggested migration path once there is a way to use custom encoders
|
||||
@deprecated(
|
||||
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
|
||||
category=None,
|
||||
)
|
||||
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
|
||||
warnings.warn(
|
||||
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
try:
|
||||
encoder = type_encoders[base]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
return encoder(obj)
|
||||
else: # We have exited the for loop without finding a suitable encoder
|
||||
return pydantic_encoder(obj)
|
||||
|
||||
|
||||
@deprecated('`timedelta_isoformat` is deprecated.', category=None)
|
||||
def timedelta_isoformat(td: datetime.timedelta) -> str:
|
||||
"""ISO 8601 encoding for Python timedelta object."""
|
||||
warnings.warn('`timedelta_isoformat` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
|
||||
minutes, seconds = divmod(td.seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'
|
||||
@@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pickle
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
|
||||
class Protocol(str, Enum):
|
||||
json = 'json'
|
||||
pickle = 'pickle'
|
||||
|
||||
|
||||
@deprecated('`load_str_bytes` is deprecated.', category=None)
|
||||
def load_str_bytes(
|
||||
b: str | bytes,
|
||||
*,
|
||||
content_type: str | None = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol | None = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
warnings.warn('`load_str_bytes` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
|
||||
if proto is None and content_type:
|
||||
if content_type.endswith(('json', 'javascript')):
|
||||
pass
|
||||
elif allow_pickle and content_type.endswith('pickle'):
|
||||
proto = Protocol.pickle
|
||||
else:
|
||||
raise TypeError(f'Unknown content-type: {content_type}')
|
||||
|
||||
proto = proto or Protocol.json
|
||||
|
||||
if proto == Protocol.json:
|
||||
if isinstance(b, bytes):
|
||||
b = b.decode(encoding)
|
||||
return json_loads(b) # type: ignore
|
||||
elif proto == Protocol.pickle:
|
||||
if not allow_pickle:
|
||||
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
|
||||
bb = b if isinstance(b, bytes) else b.encode() # type: ignore
|
||||
return pickle.loads(bb)
|
||||
else:
|
||||
raise TypeError(f'Unknown protocol: {proto}')
|
||||
|
||||
|
||||
@deprecated('`load_file` is deprecated.', category=None)
|
||||
def load_file(
|
||||
path: str | Path,
|
||||
*,
|
||||
content_type: str | None = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol | None = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
warnings.warn('`load_file` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
|
||||
path = Path(path)
|
||||
b = path.read_bytes()
|
||||
if content_type is None:
|
||||
if path.suffix in ('.js', '.json'):
|
||||
proto = Protocol.json
|
||||
elif path.suffix == '.pkl':
|
||||
proto = Protocol.pickle
|
||||
|
||||
return load_str_bytes(
|
||||
b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
|
||||
)
|
||||
103
venv/lib/python3.11/site-packages/pydantic/deprecated/tools.py
Normal file
103
venv/lib/python3.11/site-packages/pydantic/deprecated/tools.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from ..json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema
|
||||
from ..type_adapter import TypeAdapter
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
__all__ = 'parse_obj_as', 'schema_of', 'schema_json_of'
|
||||
|
||||
NameFactory = Union[str, Callable[[Type[Any]], str]]
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
|
||||
category=None,
|
||||
)
|
||||
def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) -> T:
|
||||
warnings.warn(
|
||||
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
if type_name is not None: # pragma: no cover
|
||||
warnings.warn(
|
||||
'The type_name parameter is deprecated. parse_obj_as no longer creates temporary models',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return TypeAdapter(type_).validate_python(obj)
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=None,
|
||||
)
|
||||
def schema_of(
|
||||
type_: Any,
|
||||
*,
|
||||
title: NameFactory | None = None,
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
) -> dict[str, Any]:
|
||||
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one."""
|
||||
warnings.warn(
|
||||
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
res = TypeAdapter(type_).json_schema(
|
||||
by_alias=by_alias,
|
||||
schema_generator=schema_generator,
|
||||
ref_template=ref_template,
|
||||
)
|
||||
if title is not None:
|
||||
if isinstance(title, str):
|
||||
res['title'] = title
|
||||
else:
|
||||
warnings.warn(
|
||||
'Passing a callable for the `title` parameter is deprecated and no longer supported',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
res['title'] = title(type_)
|
||||
return res
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=None,
|
||||
)
|
||||
def schema_json_of(
|
||||
type_: Any,
|
||||
*,
|
||||
title: NameFactory | None = None,
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
**dumps_kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one."""
|
||||
warnings.warn(
|
||||
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
return json.dumps(
|
||||
schema_of(type_, title=title, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator),
|
||||
**dumps_kwargs,
|
||||
)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""The `env_settings` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""The `error_wrappers` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
157
venv/lib/python3.11/site-packages/pydantic/errors.py
Normal file
157
venv/lib/python3.11/site-packages/pydantic/errors.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Pydantic-specific errors."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import re
|
||||
|
||||
from typing_extensions import Literal, Self
|
||||
|
||||
from ._migration import getattr_migration
|
||||
from .version import version_short
|
||||
|
||||
__all__ = (
|
||||
'PydanticUserError',
|
||||
'PydanticUndefinedAnnotation',
|
||||
'PydanticImportError',
|
||||
'PydanticSchemaGenerationError',
|
||||
'PydanticInvalidForJsonSchema',
|
||||
'PydanticErrorCodes',
|
||||
)
|
||||
|
||||
# We use this URL to allow for future flexibility about how we host the docs, while allowing for Pydantic
|
||||
# code in the while with "old" URLs to still work.
|
||||
# 'u' refers to "user errors" - e.g. errors caused by developers using pydantic, as opposed to validation errors.
|
||||
DEV_ERROR_DOCS_URL = f'https://errors.pydantic.dev/{version_short()}/u/'
|
||||
PydanticErrorCodes = Literal[
|
||||
'class-not-fully-defined',
|
||||
'custom-json-schema',
|
||||
'decorator-missing-field',
|
||||
'discriminator-no-field',
|
||||
'discriminator-alias-type',
|
||||
'discriminator-needs-literal',
|
||||
'discriminator-alias',
|
||||
'discriminator-validator',
|
||||
'callable-discriminator-no-tag',
|
||||
'typed-dict-version',
|
||||
'model-field-overridden',
|
||||
'model-field-missing-annotation',
|
||||
'config-both',
|
||||
'removed-kwargs',
|
||||
'invalid-for-json-schema',
|
||||
'json-schema-already-used',
|
||||
'base-model-instantiated',
|
||||
'undefined-annotation',
|
||||
'schema-for-unknown-type',
|
||||
'import-error',
|
||||
'create-model-field-definitions',
|
||||
'create-model-config-base',
|
||||
'validator-no-fields',
|
||||
'validator-invalid-fields',
|
||||
'validator-instance-method',
|
||||
'validator-input-type',
|
||||
'root-validator-pre-skip',
|
||||
'model-serializer-instance-method',
|
||||
'validator-field-config-info',
|
||||
'validator-v1-signature',
|
||||
'validator-signature',
|
||||
'field-serializer-signature',
|
||||
'model-serializer-signature',
|
||||
'multiple-field-serializers',
|
||||
'invalid-annotated-type',
|
||||
'type-adapter-config-unused',
|
||||
'root-model-extra',
|
||||
'unevaluable-type-annotation',
|
||||
'dataclass-init-false-extra-allow',
|
||||
'clashing-init-and-init-var',
|
||||
'model-config-invalid-field-name',
|
||||
'with-config-on-model',
|
||||
'dataclass-on-model',
|
||||
]
|
||||
|
||||
|
||||
class PydanticErrorMixin:
|
||||
"""A mixin class for common functionality shared by all Pydantic-specific errors.
|
||||
|
||||
Attributes:
|
||||
message: A message describing the error.
|
||||
code: An optional error code from PydanticErrorCodes enum.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, *, code: PydanticErrorCodes | None) -> None:
|
||||
self.message = message
|
||||
self.code = code
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.code is None:
|
||||
return self.message
|
||||
else:
|
||||
return f'{self.message}\n\nFor further information visit {DEV_ERROR_DOCS_URL}{self.code}'
|
||||
|
||||
|
||||
class PydanticUserError(PydanticErrorMixin, TypeError):
|
||||
"""An error raised due to incorrect use of Pydantic."""
|
||||
|
||||
|
||||
class PydanticUndefinedAnnotation(PydanticErrorMixin, NameError):
|
||||
"""A subclass of `NameError` raised when handling undefined annotations during `CoreSchema` generation.
|
||||
|
||||
Attributes:
|
||||
name: Name of the error.
|
||||
message: Description of the error.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, message: str) -> None:
|
||||
self.name = name
|
||||
super().__init__(message=message, code='undefined-annotation')
|
||||
|
||||
@classmethod
|
||||
def from_name_error(cls, name_error: NameError) -> Self:
|
||||
"""Convert a `NameError` to a `PydanticUndefinedAnnotation` error.
|
||||
|
||||
Args:
|
||||
name_error: `NameError` to be converted.
|
||||
|
||||
Returns:
|
||||
Converted `PydanticUndefinedAnnotation` error.
|
||||
"""
|
||||
try:
|
||||
name = name_error.name # type: ignore # python > 3.10
|
||||
except AttributeError:
|
||||
name = re.search(r".*'(.+?)'", str(name_error)).group(1) # type: ignore[union-attr]
|
||||
return cls(name=name, message=str(name_error))
|
||||
|
||||
|
||||
class PydanticImportError(PydanticErrorMixin, ImportError):
|
||||
"""An error raised when an import fails due to module changes between V1 and V2.
|
||||
|
||||
Attributes:
|
||||
message: Description of the error.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message, code='import-error')
|
||||
|
||||
|
||||
class PydanticSchemaGenerationError(PydanticUserError):
|
||||
"""An error raised during failures to generate a `CoreSchema` for some type.
|
||||
|
||||
Attributes:
|
||||
message: Description of the error.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message, code='schema-for-unknown-type')
|
||||
|
||||
|
||||
class PydanticInvalidForJsonSchema(PydanticUserError):
|
||||
"""An error raised during failures to generate a JSON schema for some `CoreSchema`.
|
||||
|
||||
Attributes:
|
||||
message: Description of the error.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message, code='invalid-for-json-schema')
|
||||
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
@@ -0,0 +1,10 @@
|
||||
"""The "experimental" module of pydantic contains potential new features that are subject to change."""
|
||||
|
||||
import warnings
|
||||
|
||||
from pydantic.warnings import PydanticExperimentalWarning
|
||||
|
||||
warnings.warn(
|
||||
'This module is experimental, its contents are subject to change and deprecation.',
|
||||
category=PydanticExperimentalWarning,
|
||||
)
|
||||
@@ -0,0 +1,669 @@
|
||||
"""Experimental pipeline API functionality. Be careful with this API, it's subject to change."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import operator
|
||||
import re
|
||||
import sys
|
||||
from collections import deque
|
||||
from collections.abc import Container
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from functools import cached_property, partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Pattern, Protocol, TypeVar, Union, overload
|
||||
|
||||
import annotated_types
|
||||
from typing_extensions import Annotated
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
|
||||
from pydantic._internal._internal_dataclass import slots_true as _slots_true
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
EllipsisType = type(Ellipsis)
|
||||
else:
|
||||
from types import EllipsisType
|
||||
|
||||
__all__ = ['validate_as', 'validate_as_deferred', 'transform']
|
||||
|
||||
_slots_frozen = {**_slots_true, 'frozen': True}
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _ValidateAs:
|
||||
tp: type[Any]
|
||||
strict: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ValidateAsDefer:
|
||||
func: Callable[[], type[Any]]
|
||||
|
||||
@cached_property
|
||||
def tp(self) -> type[Any]:
|
||||
return self.func()
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _Transform:
|
||||
func: Callable[[Any], Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _PipelineOr:
|
||||
left: _Pipeline[Any, Any]
|
||||
right: _Pipeline[Any, Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _PipelineAnd:
|
||||
left: _Pipeline[Any, Any]
|
||||
right: _Pipeline[Any, Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _Eq:
|
||||
value: Any
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _NotEq:
|
||||
value: Any
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _In:
|
||||
values: Container[Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _NotIn:
|
||||
values: Container[Any]
|
||||
|
||||
|
||||
_ConstraintAnnotation = Union[
|
||||
annotated_types.Le,
|
||||
annotated_types.Ge,
|
||||
annotated_types.Lt,
|
||||
annotated_types.Gt,
|
||||
annotated_types.Len,
|
||||
annotated_types.MultipleOf,
|
||||
annotated_types.Timezone,
|
||||
annotated_types.Interval,
|
||||
annotated_types.Predicate,
|
||||
# common predicates not included in annotated_types
|
||||
_Eq,
|
||||
_NotEq,
|
||||
_In,
|
||||
_NotIn,
|
||||
# regular expressions
|
||||
Pattern[str],
|
||||
]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _Constraint:
|
||||
constraint: _ConstraintAnnotation
|
||||
|
||||
|
||||
_Step = Union[_ValidateAs, _ValidateAsDefer, _Transform, _PipelineOr, _PipelineAnd, _Constraint]
|
||||
|
||||
_InT = TypeVar('_InT')
|
||||
_OutT = TypeVar('_OutT')
|
||||
_NewOutT = TypeVar('_NewOutT')
|
||||
|
||||
|
||||
class _FieldTypeMarker:
|
||||
pass
|
||||
|
||||
|
||||
# TODO: ultimately, make this public, see https://github.com/pydantic/pydantic/pull/9459#discussion_r1628197626
|
||||
# Also, make this frozen eventually, but that doesn't work right now because of the generic base
|
||||
# Which attempts to modify __orig_base__ and such.
|
||||
# We could go with a manual freeze, but that seems overkill for now.
|
||||
@dataclass(**_slots_true)
|
||||
class _Pipeline(Generic[_InT, _OutT]):
|
||||
"""Abstract representation of a chain of validation, transformation, and parsing steps."""
|
||||
|
||||
_steps: tuple[_Step, ...]
|
||||
|
||||
def transform(
|
||||
self,
|
||||
func: Callable[[_OutT], _NewOutT],
|
||||
) -> _Pipeline[_InT, _NewOutT]:
|
||||
"""Transform the output of the previous step.
|
||||
|
||||
If used as the first step in a pipeline, the type of the field is used.
|
||||
That is, the transformation is applied to after the value is parsed to the field's type.
|
||||
"""
|
||||
return _Pipeline[_InT, _NewOutT](self._steps + (_Transform(func),))
|
||||
|
||||
@overload
|
||||
def validate_as(self, tp: type[_NewOutT], *, strict: bool = ...) -> _Pipeline[_InT, _NewOutT]: ...
|
||||
|
||||
@overload
|
||||
def validate_as(self, tp: EllipsisType, *, strict: bool = ...) -> _Pipeline[_InT, Any]: # type: ignore
|
||||
...
|
||||
|
||||
def validate_as(self, tp: type[_NewOutT] | EllipsisType, *, strict: bool = False) -> _Pipeline[_InT, Any]: # type: ignore
|
||||
"""Validate / parse the input into a new type.
|
||||
|
||||
If no type is provided, the type of the field is used.
|
||||
|
||||
Types are parsed in Pydantic's `lax` mode by default,
|
||||
but you can enable `strict` mode by passing `strict=True`.
|
||||
"""
|
||||
if isinstance(tp, EllipsisType):
|
||||
return _Pipeline[_InT, Any](self._steps + (_ValidateAs(_FieldTypeMarker, strict=strict),))
|
||||
return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAs(tp, strict=strict),))
|
||||
|
||||
def validate_as_deferred(self, func: Callable[[], type[_NewOutT]]) -> _Pipeline[_InT, _NewOutT]:
|
||||
"""Parse the input into a new type, deferring resolution of the type until the current class
|
||||
is fully defined.
|
||||
|
||||
This is useful when you need to reference the class in it's own type annotations.
|
||||
"""
|
||||
return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAsDefer(func),))
|
||||
|
||||
# constraints
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutGe], constraint: annotated_types.Ge) -> _Pipeline[_InT, _NewOutGe]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutGt], constraint: annotated_types.Gt) -> _Pipeline[_InT, _NewOutGt]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutLe], constraint: annotated_types.Le) -> _Pipeline[_InT, _NewOutLe]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutLt], constraint: annotated_types.Lt) -> _Pipeline[_InT, _NewOutLt]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutLen], constraint: annotated_types.Len
|
||||
) -> _Pipeline[_InT, _NewOutLen]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutT], constraint: annotated_types.MultipleOf
|
||||
) -> _Pipeline[_InT, _NewOutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutDatetime], constraint: annotated_types.Timezone
|
||||
) -> _Pipeline[_InT, _NewOutDatetime]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: annotated_types.Predicate) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutInterval], constraint: annotated_types.Interval
|
||||
) -> _Pipeline[_InT, _NewOutInterval]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _Eq) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotEq) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _In) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotIn) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutT], constraint: Pattern[str]) -> _Pipeline[_InT, _NewOutT]: ...
|
||||
|
||||
def constrain(self, constraint: _ConstraintAnnotation) -> Any:
|
||||
"""Constrain a value to meet a certain condition.
|
||||
|
||||
We support most conditions from `annotated_types`, as well as regular expressions.
|
||||
|
||||
Most of the time you'll be calling a shortcut method like `gt`, `lt`, `len`, etc
|
||||
so you don't need to call this directly.
|
||||
"""
|
||||
return _Pipeline[_InT, _OutT](self._steps + (_Constraint(constraint),))
|
||||
|
||||
def predicate(self: _Pipeline[_InT, _NewOutT], func: Callable[[_NewOutT], bool]) -> _Pipeline[_InT, _NewOutT]:
|
||||
"""Constrain a value to meet a certain predicate."""
|
||||
return self.constrain(annotated_types.Predicate(func))
|
||||
|
||||
def gt(self: _Pipeline[_InT, _NewOutGt], gt: _NewOutGt) -> _Pipeline[_InT, _NewOutGt]:
|
||||
"""Constrain a value to be greater than a certain value."""
|
||||
return self.constrain(annotated_types.Gt(gt))
|
||||
|
||||
def lt(self: _Pipeline[_InT, _NewOutLt], lt: _NewOutLt) -> _Pipeline[_InT, _NewOutLt]:
|
||||
"""Constrain a value to be less than a certain value."""
|
||||
return self.constrain(annotated_types.Lt(lt))
|
||||
|
||||
def ge(self: _Pipeline[_InT, _NewOutGe], ge: _NewOutGe) -> _Pipeline[_InT, _NewOutGe]:
|
||||
"""Constrain a value to be greater than or equal to a certain value."""
|
||||
return self.constrain(annotated_types.Ge(ge))
|
||||
|
||||
def le(self: _Pipeline[_InT, _NewOutLe], le: _NewOutLe) -> _Pipeline[_InT, _NewOutLe]:
|
||||
"""Constrain a value to be less than or equal to a certain value."""
|
||||
return self.constrain(annotated_types.Le(le))
|
||||
|
||||
def len(self: _Pipeline[_InT, _NewOutLen], min_len: int, max_len: int | None = None) -> _Pipeline[_InT, _NewOutLen]:
|
||||
"""Constrain a value to have a certain length."""
|
||||
return self.constrain(annotated_types.Len(min_len, max_len))
|
||||
|
||||
@overload
|
||||
def multiple_of(self: _Pipeline[_InT, _NewOutDiv], multiple_of: _NewOutDiv) -> _Pipeline[_InT, _NewOutDiv]: ...
|
||||
|
||||
@overload
|
||||
def multiple_of(self: _Pipeline[_InT, _NewOutMod], multiple_of: _NewOutMod) -> _Pipeline[_InT, _NewOutMod]: ...
|
||||
|
||||
def multiple_of(self: _Pipeline[_InT, Any], multiple_of: Any) -> _Pipeline[_InT, Any]:
|
||||
"""Constrain a value to be a multiple of a certain number."""
|
||||
return self.constrain(annotated_types.MultipleOf(multiple_of))
|
||||
|
||||
def eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to be equal to a certain value."""
|
||||
return self.constrain(_Eq(value))
|
||||
|
||||
def not_eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to not be equal to a certain value."""
|
||||
return self.constrain(_NotEq(value))
|
||||
|
||||
def in_(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to be in a certain set."""
|
||||
return self.constrain(_In(values))
|
||||
|
||||
def not_in(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to not be in a certain set."""
|
||||
return self.constrain(_NotIn(values))
|
||||
|
||||
# timezone methods
|
||||
def datetime_tz_naive(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.constrain(annotated_types.Timezone(None))
|
||||
|
||||
def datetime_tz_aware(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.constrain(annotated_types.Timezone(...))
|
||||
|
||||
def datetime_tz(
|
||||
self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo
|
||||
) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.constrain(annotated_types.Timezone(tz)) # type: ignore
|
||||
|
||||
def datetime_with_tz(
|
||||
self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo | None
|
||||
) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.transform(partial(datetime.datetime.replace, tzinfo=tz))
|
||||
|
||||
# string methods
|
||||
def str_lower(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.lower)
|
||||
|
||||
def str_upper(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.upper)
|
||||
|
||||
def str_title(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.title)
|
||||
|
||||
def str_strip(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.strip)
|
||||
|
||||
def str_pattern(self: _Pipeline[_InT, str], pattern: str) -> _Pipeline[_InT, str]:
|
||||
return self.constrain(re.compile(pattern))
|
||||
|
||||
def str_contains(self: _Pipeline[_InT, str], substring: str) -> _Pipeline[_InT, str]:
|
||||
return self.predicate(lambda v: substring in v)
|
||||
|
||||
def str_starts_with(self: _Pipeline[_InT, str], prefix: str) -> _Pipeline[_InT, str]:
|
||||
return self.predicate(lambda v: v.startswith(prefix))
|
||||
|
||||
def str_ends_with(self: _Pipeline[_InT, str], suffix: str) -> _Pipeline[_InT, str]:
|
||||
return self.predicate(lambda v: v.endswith(suffix))
|
||||
|
||||
# operators
|
||||
def otherwise(self, other: _Pipeline[_OtherIn, _OtherOut]) -> _Pipeline[_InT | _OtherIn, _OutT | _OtherOut]:
|
||||
"""Combine two validation chains, returning the result of the first chain if it succeeds, and the second chain if it fails."""
|
||||
return _Pipeline((_PipelineOr(self, other),))
|
||||
|
||||
__or__ = otherwise
|
||||
|
||||
def then(self, other: _Pipeline[_OutT, _OtherOut]) -> _Pipeline[_InT, _OtherOut]:
|
||||
"""Pipe the result of one validation chain into another."""
|
||||
return _Pipeline((_PipelineAnd(self, other),))
|
||||
|
||||
__and__ = then
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
queue = deque(self._steps)
|
||||
|
||||
s = None
|
||||
|
||||
while queue:
|
||||
step = queue.popleft()
|
||||
s = _apply_step(step, s, handler, source_type)
|
||||
|
||||
s = s or cs.any_schema()
|
||||
return s
|
||||
|
||||
def __supports_type__(self, _: _OutT) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
validate_as = _Pipeline[Any, Any](()).validate_as
|
||||
validate_as_deferred = _Pipeline[Any, Any](()).validate_as_deferred
|
||||
transform = _Pipeline[Any, Any]((_ValidateAs(_FieldTypeMarker),)).transform
|
||||
|
||||
|
||||
def _check_func(
|
||||
func: Callable[[Any], bool], predicate_err: str | Callable[[], str], s: cs.CoreSchema | None
|
||||
) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
def handler(v: Any) -> Any:
|
||||
if func(v):
|
||||
return v
|
||||
raise ValueError(f'Expected {predicate_err if isinstance(predicate_err, str) else predicate_err()}')
|
||||
|
||||
if s is None:
|
||||
return cs.no_info_plain_validator_function(handler)
|
||||
else:
|
||||
return cs.no_info_after_validator_function(handler, s)
|
||||
|
||||
|
||||
def _apply_step(step: _Step, s: cs.CoreSchema | None, handler: GetCoreSchemaHandler, source_type: Any) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
if isinstance(step, _ValidateAs):
|
||||
s = _apply_parse(s, step.tp, step.strict, handler, source_type)
|
||||
elif isinstance(step, _ValidateAsDefer):
|
||||
s = _apply_parse(s, step.tp, False, handler, source_type)
|
||||
elif isinstance(step, _Transform):
|
||||
s = _apply_transform(s, step.func, handler)
|
||||
elif isinstance(step, _Constraint):
|
||||
s = _apply_constraint(s, step.constraint)
|
||||
elif isinstance(step, _PipelineOr):
|
||||
s = cs.union_schema([handler(step.left), handler(step.right)])
|
||||
else:
|
||||
assert isinstance(step, _PipelineAnd)
|
||||
s = cs.chain_schema([handler(step.left), handler(step.right)])
|
||||
return s
|
||||
|
||||
|
||||
def _apply_parse(
|
||||
s: cs.CoreSchema | None,
|
||||
tp: type[Any],
|
||||
strict: bool,
|
||||
handler: GetCoreSchemaHandler,
|
||||
source_type: Any,
|
||||
) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from pydantic import Strict
|
||||
|
||||
if tp is _FieldTypeMarker:
|
||||
return handler(source_type)
|
||||
|
||||
if strict:
|
||||
tp = Annotated[tp, Strict()] # type: ignore
|
||||
|
||||
if s and s['type'] == 'any':
|
||||
return handler(tp)
|
||||
else:
|
||||
return cs.chain_schema([s, handler(tp)]) if s else handler(tp)
|
||||
|
||||
|
||||
def _apply_transform(
|
||||
s: cs.CoreSchema | None, func: Callable[[Any], Any], handler: GetCoreSchemaHandler
|
||||
) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
if s is None:
|
||||
return cs.no_info_plain_validator_function(func)
|
||||
|
||||
if s['type'] == 'str':
|
||||
if func is str.strip:
|
||||
s = s.copy()
|
||||
s['strip_whitespace'] = True
|
||||
return s
|
||||
elif func is str.lower:
|
||||
s = s.copy()
|
||||
s['to_lower'] = True
|
||||
return s
|
||||
elif func is str.upper:
|
||||
s = s.copy()
|
||||
s['to_upper'] = True
|
||||
return s
|
||||
|
||||
return cs.no_info_after_validator_function(func, s)
|
||||
|
||||
|
||||
def _apply_constraint( # noqa: C901
|
||||
s: cs.CoreSchema | None, constraint: _ConstraintAnnotation
|
||||
) -> cs.CoreSchema:
|
||||
"""Apply a single constraint to a schema."""
|
||||
if isinstance(constraint, annotated_types.Gt):
|
||||
gt = constraint.gt
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(gt, int):
|
||||
s['gt'] = gt
|
||||
elif s['type'] == 'float' and isinstance(gt, float):
|
||||
s['gt'] = gt
|
||||
elif s['type'] == 'decimal' and isinstance(gt, Decimal):
|
||||
s['gt'] = gt
|
||||
else:
|
||||
|
||||
def check_gt(v: Any) -> bool:
|
||||
return v > gt
|
||||
|
||||
s = _check_func(check_gt, f'> {gt}', s)
|
||||
elif isinstance(constraint, annotated_types.Ge):
|
||||
ge = constraint.ge
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(ge, int):
|
||||
s['ge'] = ge
|
||||
elif s['type'] == 'float' and isinstance(ge, float):
|
||||
s['ge'] = ge
|
||||
elif s['type'] == 'decimal' and isinstance(ge, Decimal):
|
||||
s['ge'] = ge
|
||||
|
||||
def check_ge(v: Any) -> bool:
|
||||
return v >= ge
|
||||
|
||||
s = _check_func(check_ge, f'>= {ge}', s)
|
||||
elif isinstance(constraint, annotated_types.Lt):
|
||||
lt = constraint.lt
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(lt, int):
|
||||
s['lt'] = lt
|
||||
elif s['type'] == 'float' and isinstance(lt, float):
|
||||
s['lt'] = lt
|
||||
elif s['type'] == 'decimal' and isinstance(lt, Decimal):
|
||||
s['lt'] = lt
|
||||
|
||||
def check_lt(v: Any) -> bool:
|
||||
return v < lt
|
||||
|
||||
s = _check_func(check_lt, f'< {lt}', s)
|
||||
elif isinstance(constraint, annotated_types.Le):
|
||||
le = constraint.le
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(le, int):
|
||||
s['le'] = le
|
||||
elif s['type'] == 'float' and isinstance(le, float):
|
||||
s['le'] = le
|
||||
elif s['type'] == 'decimal' and isinstance(le, Decimal):
|
||||
s['le'] = le
|
||||
|
||||
def check_le(v: Any) -> bool:
|
||||
return v <= le
|
||||
|
||||
s = _check_func(check_le, f'<= {le}', s)
|
||||
elif isinstance(constraint, annotated_types.Len):
|
||||
min_len = constraint.min_length
|
||||
max_len = constraint.max_length
|
||||
|
||||
if s and s['type'] in {'str', 'list', 'tuple', 'set', 'frozenset', 'dict'}:
|
||||
assert (
|
||||
s['type'] == 'str'
|
||||
or s['type'] == 'list'
|
||||
or s['type'] == 'tuple'
|
||||
or s['type'] == 'set'
|
||||
or s['type'] == 'dict'
|
||||
or s['type'] == 'frozenset'
|
||||
)
|
||||
s = s.copy()
|
||||
if min_len != 0:
|
||||
s['min_length'] = min_len
|
||||
if max_len is not None:
|
||||
s['max_length'] = max_len
|
||||
|
||||
def check_len(v: Any) -> bool:
|
||||
if max_len is not None:
|
||||
return (min_len <= len(v)) and (len(v) <= max_len)
|
||||
return min_len <= len(v)
|
||||
|
||||
s = _check_func(check_len, f'length >= {min_len} and length <= {max_len}', s)
|
||||
elif isinstance(constraint, annotated_types.MultipleOf):
|
||||
multiple_of = constraint.multiple_of
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(multiple_of, int):
|
||||
s['multiple_of'] = multiple_of
|
||||
elif s['type'] == 'float' and isinstance(multiple_of, float):
|
||||
s['multiple_of'] = multiple_of
|
||||
elif s['type'] == 'decimal' and isinstance(multiple_of, Decimal):
|
||||
s['multiple_of'] = multiple_of
|
||||
|
||||
def check_multiple_of(v: Any) -> bool:
|
||||
return v % multiple_of == 0
|
||||
|
||||
s = _check_func(check_multiple_of, f'% {multiple_of} == 0', s)
|
||||
elif isinstance(constraint, annotated_types.Timezone):
|
||||
tz = constraint.tz
|
||||
|
||||
if tz is ...:
|
||||
if s and s['type'] == 'datetime':
|
||||
s = s.copy()
|
||||
s['tz_constraint'] = 'aware'
|
||||
else:
|
||||
|
||||
def check_tz_aware(v: object) -> bool:
|
||||
assert isinstance(v, datetime.datetime)
|
||||
return v.tzinfo is not None
|
||||
|
||||
s = _check_func(check_tz_aware, 'timezone aware', s)
|
||||
elif tz is None:
|
||||
if s and s['type'] == 'datetime':
|
||||
s = s.copy()
|
||||
s['tz_constraint'] = 'naive'
|
||||
else:
|
||||
|
||||
def check_tz_naive(v: object) -> bool:
|
||||
assert isinstance(v, datetime.datetime)
|
||||
return v.tzinfo is None
|
||||
|
||||
s = _check_func(check_tz_naive, 'timezone naive', s)
|
||||
else:
|
||||
raise NotImplementedError('Constraining to a specific timezone is not yet supported')
|
||||
elif isinstance(constraint, annotated_types.Interval):
|
||||
if constraint.ge:
|
||||
s = _apply_constraint(s, annotated_types.Ge(constraint.ge))
|
||||
if constraint.gt:
|
||||
s = _apply_constraint(s, annotated_types.Gt(constraint.gt))
|
||||
if constraint.le:
|
||||
s = _apply_constraint(s, annotated_types.Le(constraint.le))
|
||||
if constraint.lt:
|
||||
s = _apply_constraint(s, annotated_types.Lt(constraint.lt))
|
||||
assert s is not None
|
||||
elif isinstance(constraint, annotated_types.Predicate):
|
||||
func = constraint.func
|
||||
|
||||
if func.__name__ == '<lambda>':
|
||||
# attempt to extract the source code for a lambda function
|
||||
# to use as the function name in error messages
|
||||
# TODO: is there a better way? should we just not do this?
|
||||
import inspect
|
||||
|
||||
try:
|
||||
# remove ')' suffix, can use removesuffix once we drop 3.8
|
||||
source = inspect.getsource(func).strip()
|
||||
if source.endswith(')'):
|
||||
source = source[:-1]
|
||||
lambda_source_code = '`' + ''.join(''.join(source.split('lambda ')[1:]).split(':')[1:]).strip() + '`'
|
||||
except OSError:
|
||||
# stringified annotations
|
||||
lambda_source_code = 'lambda'
|
||||
|
||||
s = _check_func(func, lambda_source_code, s)
|
||||
else:
|
||||
s = _check_func(func, func.__name__, s)
|
||||
elif isinstance(constraint, _NotEq):
|
||||
value = constraint.value
|
||||
|
||||
def check_not_eq(v: Any) -> bool:
|
||||
return operator.__ne__(v, value)
|
||||
|
||||
s = _check_func(check_not_eq, f'!= {value}', s)
|
||||
elif isinstance(constraint, _Eq):
|
||||
value = constraint.value
|
||||
|
||||
def check_eq(v: Any) -> bool:
|
||||
return operator.__eq__(v, value)
|
||||
|
||||
s = _check_func(check_eq, f'== {value}', s)
|
||||
elif isinstance(constraint, _In):
|
||||
values = constraint.values
|
||||
|
||||
def check_in(v: Any) -> bool:
|
||||
return operator.__contains__(values, v)
|
||||
|
||||
s = _check_func(check_in, f'in {values}', s)
|
||||
elif isinstance(constraint, _NotIn):
|
||||
values = constraint.values
|
||||
|
||||
def check_not_in(v: Any) -> bool:
|
||||
return operator.__not__(operator.__contains__(values, v))
|
||||
|
||||
s = _check_func(check_not_in, f'not in {values}', s)
|
||||
else:
|
||||
assert isinstance(constraint, Pattern)
|
||||
if s and s['type'] == 'str':
|
||||
s = s.copy()
|
||||
s['pattern'] = constraint.pattern
|
||||
else:
|
||||
|
||||
def check_pattern(v: object) -> bool:
|
||||
assert isinstance(v, str)
|
||||
return constraint.match(v) is not None
|
||||
|
||||
s = _check_func(check_pattern, f'~ {constraint.pattern}', s)
|
||||
return s
|
||||
|
||||
|
||||
class _SupportsRange(annotated_types.SupportsLe, annotated_types.SupportsGe, Protocol):
|
||||
pass
|
||||
|
||||
|
||||
class _SupportsLen(Protocol):
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
|
||||
_NewOutGt = TypeVar('_NewOutGt', bound=annotated_types.SupportsGt)
|
||||
_NewOutGe = TypeVar('_NewOutGe', bound=annotated_types.SupportsGe)
|
||||
_NewOutLt = TypeVar('_NewOutLt', bound=annotated_types.SupportsLt)
|
||||
_NewOutLe = TypeVar('_NewOutLe', bound=annotated_types.SupportsLe)
|
||||
_NewOutLen = TypeVar('_NewOutLen', bound=_SupportsLen)
|
||||
_NewOutDiv = TypeVar('_NewOutDiv', bound=annotated_types.SupportsDiv)
|
||||
_NewOutMod = TypeVar('_NewOutMod', bound=annotated_types.SupportsMod)
|
||||
_NewOutDatetime = TypeVar('_NewOutDatetime', bound=datetime.datetime)
|
||||
_NewOutInterval = TypeVar('_NewOutInterval', bound=_SupportsRange)
|
||||
_OtherIn = TypeVar('_OtherIn')
|
||||
_OtherOut = TypeVar('_OtherOut')
|
||||
1266
venv/lib/python3.11/site-packages/pydantic/fields.py
Normal file
1266
venv/lib/python3.11/site-packages/pydantic/fields.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,438 @@
|
||||
"""This module contains related classes and functions for serialization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from functools import partial, partialmethod
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
|
||||
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler, WhenUsed
|
||||
from typing_extensions import Annotated, Literal, TypeAlias
|
||||
|
||||
from . import PydanticUndefinedAnnotation
|
||||
from ._internal import _decorators, _internal_dataclass
|
||||
from .annotated_handlers import GetCoreSchemaHandler
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
|
||||
class PlainSerializer:
|
||||
"""Plain serializers use a function to modify the output of serialization.
|
||||
|
||||
This is particularly helpful when you want to customize the serialization for annotated types.
|
||||
Consider an input of `list`, which will be serialized into a space-delimited string.
|
||||
|
||||
```python
|
||||
from typing import List
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, PlainSerializer
|
||||
|
||||
CustomStr = Annotated[
|
||||
List, PlainSerializer(lambda x: ' '.join(x), return_type=str)
|
||||
]
|
||||
|
||||
class StudentModel(BaseModel):
|
||||
courses: CustomStr
|
||||
|
||||
student = StudentModel(courses=['Math', 'Chemistry', 'English'])
|
||||
print(student.model_dump())
|
||||
#> {'courses': 'Math Chemistry English'}
|
||||
```
|
||||
|
||||
Attributes:
|
||||
func: The serializer function.
|
||||
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
|
||||
when_used: Determines when this serializer should be used. Accepts a string with values `'always'`,
|
||||
`'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'.
|
||||
"""
|
||||
|
||||
func: core_schema.SerializerFunction
|
||||
return_type: Any = PydanticUndefined
|
||||
when_used: WhenUsed = 'always'
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
"""Gets the Pydantic core schema.
|
||||
|
||||
Args:
|
||||
source_type: The source type.
|
||||
handler: The `GetCoreSchemaHandler` instance.
|
||||
|
||||
Returns:
|
||||
The Pydantic core schema.
|
||||
"""
|
||||
schema = handler(source_type)
|
||||
try:
|
||||
return_type = _decorators.get_function_return_type(
|
||||
self.func, self.return_type, handler._get_types_namespace()
|
||||
)
|
||||
except NameError as e:
|
||||
raise PydanticUndefinedAnnotation.from_name_error(e) from e
|
||||
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
|
||||
schema['serialization'] = core_schema.plain_serializer_function_ser_schema(
|
||||
function=self.func,
|
||||
info_arg=_decorators.inspect_annotated_serializer(self.func, 'plain'),
|
||||
return_schema=return_schema,
|
||||
when_used=self.when_used,
|
||||
)
|
||||
return schema
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
|
||||
class WrapSerializer:
|
||||
"""Wrap serializers receive the raw inputs along with a handler function that applies the standard serialization
|
||||
logic, and can modify the resulting value before returning it as the final output of serialization.
|
||||
|
||||
For example, here's a scenario in which a wrap serializer transforms timezones to UTC **and** utilizes the existing `datetime` serialization logic.
|
||||
|
||||
```python
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, WrapSerializer
|
||||
|
||||
class EventDatetime(BaseModel):
|
||||
start: datetime
|
||||
end: datetime
|
||||
|
||||
def convert_to_utc(value: Any, handler, info) -> Dict[str, datetime]:
|
||||
# Note that `handler` can actually help serialize the `value` for
|
||||
# further custom serialization in case it's a subclass.
|
||||
partial_result = handler(value, info)
|
||||
if info.mode == 'json':
|
||||
return {
|
||||
k: datetime.fromisoformat(v).astimezone(timezone.utc)
|
||||
for k, v in partial_result.items()
|
||||
}
|
||||
return {k: v.astimezone(timezone.utc) for k, v in partial_result.items()}
|
||||
|
||||
UTCEventDatetime = Annotated[EventDatetime, WrapSerializer(convert_to_utc)]
|
||||
|
||||
class EventModel(BaseModel):
|
||||
event_datetime: UTCEventDatetime
|
||||
|
||||
dt = EventDatetime(
|
||||
start='2024-01-01T07:00:00-08:00', end='2024-01-03T20:00:00+06:00'
|
||||
)
|
||||
event = EventModel(event_datetime=dt)
|
||||
print(event.model_dump())
|
||||
'''
|
||||
{
|
||||
'event_datetime': {
|
||||
'start': datetime.datetime(
|
||||
2024, 1, 1, 15, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
'end': datetime.datetime(
|
||||
2024, 1, 3, 14, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
print(event.model_dump_json())
|
||||
'''
|
||||
{"event_datetime":{"start":"2024-01-01T15:00:00Z","end":"2024-01-03T14:00:00Z"}}
|
||||
'''
|
||||
```
|
||||
|
||||
Attributes:
|
||||
func: The serializer function to be wrapped.
|
||||
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
|
||||
when_used: Determines when this serializer should be used. Accepts a string with values `'always'`,
|
||||
`'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'.
|
||||
"""
|
||||
|
||||
func: core_schema.WrapSerializerFunction
|
||||
return_type: Any = PydanticUndefined
|
||||
when_used: WhenUsed = 'always'
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
"""This method is used to get the Pydantic core schema of the class.
|
||||
|
||||
Args:
|
||||
source_type: Source type.
|
||||
handler: Core schema handler.
|
||||
|
||||
Returns:
|
||||
The generated core schema of the class.
|
||||
"""
|
||||
schema = handler(source_type)
|
||||
try:
|
||||
return_type = _decorators.get_function_return_type(
|
||||
self.func, self.return_type, handler._get_types_namespace()
|
||||
)
|
||||
except NameError as e:
|
||||
raise PydanticUndefinedAnnotation.from_name_error(e) from e
|
||||
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
|
||||
schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
|
||||
function=self.func,
|
||||
info_arg=_decorators.inspect_annotated_serializer(self.func, 'wrap'),
|
||||
return_schema=return_schema,
|
||||
when_used=self.when_used,
|
||||
)
|
||||
return schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
_Partial: TypeAlias = 'partial[Any] | partialmethod[Any]'
|
||||
|
||||
FieldPlainSerializer: TypeAlias = 'core_schema.SerializerFunction | _Partial'
|
||||
"""A field serializer method or function in `plain` mode."""
|
||||
|
||||
FieldWrapSerializer: TypeAlias = 'core_schema.WrapSerializerFunction | _Partial'
|
||||
"""A field serializer method or function in `wrap` mode."""
|
||||
|
||||
FieldSerializer: TypeAlias = 'FieldPlainSerializer | FieldWrapSerializer'
|
||||
"""A field serializer method or function."""
|
||||
|
||||
_FieldPlainSerializerT = TypeVar('_FieldPlainSerializerT', bound=FieldPlainSerializer)
|
||||
_FieldWrapSerializerT = TypeVar('_FieldWrapSerializerT', bound=FieldWrapSerializer)
|
||||
|
||||
|
||||
@overload
|
||||
def field_serializer(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['wrap'],
|
||||
return_type: Any = ...,
|
||||
when_used: WhenUsed = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def field_serializer(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['plain'] = ...,
|
||||
return_type: Any = ...,
|
||||
when_used: WhenUsed = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]: ...
|
||||
|
||||
|
||||
def field_serializer(
|
||||
*fields: str,
|
||||
mode: Literal['plain', 'wrap'] = 'plain',
|
||||
return_type: Any = PydanticUndefined,
|
||||
when_used: WhenUsed = 'always',
|
||||
check_fields: bool | None = None,
|
||||
) -> (
|
||||
Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]
|
||||
| Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]
|
||||
):
|
||||
"""Decorator that enables custom field serialization.
|
||||
|
||||
In the below example, a field of type `set` is used to mitigate duplication. A `field_serializer` is used to serialize the data as a sorted list.
|
||||
|
||||
```python
|
||||
from typing import Set
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
class StudentModel(BaseModel):
|
||||
name: str = 'Jane'
|
||||
courses: Set[str]
|
||||
|
||||
@field_serializer('courses', when_used='json')
|
||||
def serialize_courses_in_order(self, courses: Set[str]):
|
||||
return sorted(courses)
|
||||
|
||||
student = StudentModel(courses={'Math', 'Chemistry', 'English'})
|
||||
print(student.model_dump_json())
|
||||
#> {"name":"Jane","courses":["Chemistry","English","Math"]}
|
||||
```
|
||||
|
||||
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
|
||||
|
||||
Four signatures are supported:
|
||||
|
||||
- `(self, value: Any, info: FieldSerializationInfo)`
|
||||
- `(self, value: Any, nxt: SerializerFunctionWrapHandler, info: FieldSerializationInfo)`
|
||||
- `(value: Any, info: SerializationInfo)`
|
||||
- `(value: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)`
|
||||
|
||||
Args:
|
||||
fields: Which field(s) the method should be called on.
|
||||
mode: The serialization mode.
|
||||
|
||||
- `plain` means the function will be called instead of the default serialization logic,
|
||||
- `wrap` means the function will be called with an argument to optionally call the
|
||||
default serialization logic.
|
||||
return_type: Optional return type for the function, if omitted it will be inferred from the type annotation.
|
||||
when_used: Determines the serializer will be used for serialization.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
|
||||
Returns:
|
||||
The decorator function.
|
||||
"""
|
||||
|
||||
def dec(f: FieldSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
dec_info = _decorators.FieldSerializerDecoratorInfo(
|
||||
fields=fields,
|
||||
mode=mode,
|
||||
return_type=return_type,
|
||||
when_used=when_used,
|
||||
check_fields=check_fields,
|
||||
)
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info) # pyright: ignore[reportArgumentType]
|
||||
|
||||
return dec # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# The first argument in the following callables represent the `self` type:
|
||||
|
||||
ModelPlainSerializerWithInfo: TypeAlias = Callable[[Any, SerializationInfo], Any]
|
||||
"""A model serializer method with the `info` argument, in `plain` mode."""
|
||||
|
||||
ModelPlainSerializerWithoutInfo: TypeAlias = Callable[[Any], Any]
|
||||
"""A model serializer method without the `info` argument, in `plain` mode."""
|
||||
|
||||
ModelPlainSerializer: TypeAlias = 'ModelPlainSerializerWithInfo | ModelPlainSerializerWithoutInfo'
|
||||
"""A model serializer method in `plain` mode."""
|
||||
|
||||
ModelWrapSerializerWithInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any]
|
||||
"""A model serializer method with the `info` argument, in `wrap` mode."""
|
||||
|
||||
ModelWrapSerializerWithoutInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler], Any]
|
||||
"""A model serializer method without the `info` argument, in `wrap` mode."""
|
||||
|
||||
ModelWrapSerializer: TypeAlias = 'ModelWrapSerializerWithInfo | ModelWrapSerializerWithoutInfo'
|
||||
"""A model serializer method in `wrap` mode."""
|
||||
|
||||
ModelSerializer: TypeAlias = 'ModelPlainSerializer | ModelWrapSerializer'
|
||||
|
||||
_ModelPlainSerializerT = TypeVar('_ModelPlainSerializerT', bound=ModelPlainSerializer)
|
||||
_ModelWrapSerializerT = TypeVar('_ModelWrapSerializerT', bound=ModelWrapSerializer)
|
||||
|
||||
|
||||
@overload
|
||||
def model_serializer(f: _ModelPlainSerializerT, /) -> _ModelPlainSerializerT: ...
|
||||
|
||||
|
||||
@overload
|
||||
def model_serializer(
|
||||
*, mode: Literal['wrap'], when_used: WhenUsed = 'always', return_type: Any = ...
|
||||
) -> Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def model_serializer(
|
||||
*,
|
||||
mode: Literal['plain'] = ...,
|
||||
when_used: WhenUsed = 'always',
|
||||
return_type: Any = ...,
|
||||
) -> Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]: ...
|
||||
|
||||
|
||||
def model_serializer(
|
||||
f: _ModelPlainSerializerT | _ModelWrapSerializerT | None = None,
|
||||
/,
|
||||
*,
|
||||
mode: Literal['plain', 'wrap'] = 'plain',
|
||||
when_used: WhenUsed = 'always',
|
||||
return_type: Any = PydanticUndefined,
|
||||
) -> (
|
||||
_ModelPlainSerializerT
|
||||
| Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]
|
||||
| Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]
|
||||
):
|
||||
"""Decorator that enables custom model serialization.
|
||||
|
||||
This is useful when a model need to be serialized in a customized manner, allowing for flexibility beyond just specific fields.
|
||||
|
||||
An example would be to serialize temperature to the same temperature scale, such as degrees Celsius.
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
|
||||
class TemperatureModel(BaseModel):
|
||||
unit: Literal['C', 'F']
|
||||
value: int
|
||||
|
||||
@model_serializer()
|
||||
def serialize_model(self):
|
||||
if self.unit == 'F':
|
||||
return {'unit': 'C', 'value': int((self.value - 32) / 1.8)}
|
||||
return {'unit': self.unit, 'value': self.value}
|
||||
|
||||
temperature = TemperatureModel(unit='F', value=212)
|
||||
print(temperature.model_dump())
|
||||
#> {'unit': 'C', 'value': 100}
|
||||
```
|
||||
|
||||
Two signatures are supported for `mode='plain'`, which is the default:
|
||||
|
||||
- `(self)`
|
||||
- `(self, info: SerializationInfo)`
|
||||
|
||||
And two other signatures for `mode='wrap'`:
|
||||
|
||||
- `(self, nxt: SerializerFunctionWrapHandler)`
|
||||
- `(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)`
|
||||
|
||||
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
|
||||
|
||||
Args:
|
||||
f: The function to be decorated.
|
||||
mode: The serialization mode.
|
||||
|
||||
- `'plain'` means the function will be called instead of the default serialization logic
|
||||
- `'wrap'` means the function will be called with an argument to optionally call the default
|
||||
serialization logic.
|
||||
when_used: Determines when this serializer should be used.
|
||||
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
|
||||
|
||||
Returns:
|
||||
The decorator function.
|
||||
"""
|
||||
|
||||
def dec(f: ModelSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
dec_info = _decorators.ModelSerializerDecoratorInfo(mode=mode, return_type=return_type, when_used=when_used)
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info)
|
||||
|
||||
if f is None:
|
||||
return dec # pyright: ignore[reportReturnType]
|
||||
else:
|
||||
return dec(f) # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
AnyType = TypeVar('AnyType')
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
SerializeAsAny = Annotated[AnyType, ...] # SerializeAsAny[list[str]] will be treated by type checkers as list[str]
|
||||
"""Force serialization to ignore whatever is defined in the schema and instead ask the object
|
||||
itself how it should be serialized.
|
||||
In particular, this means that when model subclasses are serialized, fields present in the subclass
|
||||
but not in the original schema will be included.
|
||||
"""
|
||||
else:
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class SerializeAsAny: # noqa: D101
|
||||
def __class_getitem__(cls, item: Any) -> Any:
|
||||
return Annotated[item, SerializeAsAny()]
|
||||
|
||||
def __get_pydantic_core_schema__(
|
||||
self, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = handler(source_type)
|
||||
schema_to_update = schema
|
||||
while schema_to_update['type'] == 'definitions':
|
||||
schema_to_update = schema_to_update.copy()
|
||||
schema_to_update = schema_to_update['schema']
|
||||
schema_to_update['serialization'] = core_schema.wrap_serializer_function_ser_schema(
|
||||
lambda x, h: h(x), schema=core_schema.any_schema()
|
||||
)
|
||||
return schema
|
||||
|
||||
__hash__ = object.__hash__
|
||||
@@ -0,0 +1,808 @@
|
||||
"""This module contains related classes and functions for validation."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
from functools import partialmethod
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload
|
||||
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from pydantic_core import core_schema as _core_schema
|
||||
from typing_extensions import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from ._internal import _core_metadata, _decorators, _generics, _internal_dataclass
|
||||
from .annotated_handlers import GetCoreSchemaHandler
|
||||
from .errors import PydanticUserError
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import Protocol
|
||||
else:
|
||||
from typing import Protocol
|
||||
|
||||
_inspect_validator = _decorators.inspect_validator
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class AfterValidator:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **after** the inner validation logic.
|
||||
|
||||
Attributes:
|
||||
func: The validator function.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, ValidationError
|
||||
|
||||
MyInt = Annotated[int, AfterValidator(lambda v: v + 1)]
|
||||
|
||||
class Model(BaseModel):
|
||||
a: MyInt
|
||||
|
||||
print(Model(a=1).a)
|
||||
#> 2
|
||||
|
||||
try:
|
||||
Model(a='a')
|
||||
except ValidationError as e:
|
||||
print(e.json(indent=2))
|
||||
'''
|
||||
[
|
||||
{
|
||||
"type": "int_parsing",
|
||||
"loc": [
|
||||
"a"
|
||||
],
|
||||
"msg": "Input should be a valid integer, unable to parse string as an integer",
|
||||
"input": "a",
|
||||
"url": "https://errors.pydantic.dev/2/v/int_parsing"
|
||||
}
|
||||
]
|
||||
'''
|
||||
```
|
||||
"""
|
||||
|
||||
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
schema = handler(source_type)
|
||||
info_arg = _inspect_validator(self.func, 'after')
|
||||
if info_arg:
|
||||
func = cast(core_schema.WithInfoValidatorFunction, self.func)
|
||||
return core_schema.with_info_after_validator_function(func, schema=schema, field_name=handler.field_name)
|
||||
else:
|
||||
func = cast(core_schema.NoInfoValidatorFunction, self.func)
|
||||
return core_schema.no_info_after_validator_function(func, schema=schema)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(func=decorator.func)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class BeforeValidator:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **before** the inner validation logic.
|
||||
|
||||
Attributes:
|
||||
func: The validator function.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
|
||||
JSON Schema (in validation mode).
|
||||
|
||||
Example:
|
||||
```py
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, BeforeValidator
|
||||
|
||||
MyInt = Annotated[int, BeforeValidator(lambda v: v + 1)]
|
||||
|
||||
class Model(BaseModel):
|
||||
a: MyInt
|
||||
|
||||
print(Model(a=1).a)
|
||||
#> 2
|
||||
|
||||
try:
|
||||
Model(a='a')
|
||||
except TypeError as e:
|
||||
print(e)
|
||||
#> can only concatenate str (not "int") to str
|
||||
```
|
||||
"""
|
||||
|
||||
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
|
||||
json_schema_input_type: Any = PydanticUndefined
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
schema = handler(source_type)
|
||||
input_schema = (
|
||||
None
|
||||
if self.json_schema_input_type is PydanticUndefined
|
||||
else handler.generate_schema(self.json_schema_input_type)
|
||||
)
|
||||
metadata = _core_metadata.build_metadata_dict(js_input_core_schema=input_schema)
|
||||
|
||||
info_arg = _inspect_validator(self.func, 'before')
|
||||
if info_arg:
|
||||
func = cast(core_schema.WithInfoValidatorFunction, self.func)
|
||||
return core_schema.with_info_before_validator_function(
|
||||
func,
|
||||
schema=schema,
|
||||
field_name=handler.field_name,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
func = cast(core_schema.NoInfoValidatorFunction, self.func)
|
||||
return core_schema.no_info_before_validator_function(func, schema=schema, metadata=metadata)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(
|
||||
func=decorator.func,
|
||||
json_schema_input_type=decorator.info.json_schema_input_type,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class PlainValidator:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **instead** of the inner validation logic.
|
||||
|
||||
Attributes:
|
||||
func: The validator function.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
|
||||
JSON Schema (in validation mode). If not provided, will default to `Any`.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, PlainValidator
|
||||
|
||||
MyInt = Annotated[int, PlainValidator(lambda v: int(v) + 1)]
|
||||
|
||||
class Model(BaseModel):
|
||||
a: MyInt
|
||||
|
||||
print(Model(a='1').a)
|
||||
#> 2
|
||||
```
|
||||
"""
|
||||
|
||||
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
|
||||
json_schema_input_type: Any = Any
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
# Note that for some valid uses of PlainValidator, it is not possible to generate a core schema for the
|
||||
# source_type, so calling `handler(source_type)` will error, which prevents us from generating a proper
|
||||
# serialization schema. To work around this for use cases that will not involve serialization, we simply
|
||||
# catch any PydanticSchemaGenerationError that may be raised while attempting to build the serialization schema
|
||||
# and abort any attempts to handle special serialization.
|
||||
from pydantic import PydanticSchemaGenerationError
|
||||
|
||||
try:
|
||||
schema = handler(source_type)
|
||||
# TODO if `schema['serialization']` is one of `'include-exclude-dict/sequence',
|
||||
# schema validation will fail. That's why we use 'type ignore' comments below.
|
||||
serialization = schema.get(
|
||||
'serialization',
|
||||
core_schema.wrap_serializer_function_ser_schema(
|
||||
function=lambda v, h: h(v),
|
||||
schema=schema,
|
||||
return_schema=handler.generate_schema(source_type),
|
||||
),
|
||||
)
|
||||
except PydanticSchemaGenerationError:
|
||||
serialization = None
|
||||
|
||||
input_schema = handler.generate_schema(self.json_schema_input_type)
|
||||
metadata = _core_metadata.build_metadata_dict(js_input_core_schema=input_schema)
|
||||
|
||||
info_arg = _inspect_validator(self.func, 'plain')
|
||||
if info_arg:
|
||||
func = cast(core_schema.WithInfoValidatorFunction, self.func)
|
||||
return core_schema.with_info_plain_validator_function(
|
||||
func,
|
||||
field_name=handler.field_name,
|
||||
serialization=serialization, # pyright: ignore[reportArgumentType]
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
func = cast(core_schema.NoInfoValidatorFunction, self.func)
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
func,
|
||||
serialization=serialization, # pyright: ignore[reportArgumentType]
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(
|
||||
func=decorator.func,
|
||||
json_schema_input_type=decorator.info.json_schema_input_type,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class WrapValidator:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **around** the inner validation logic.
|
||||
|
||||
Attributes:
|
||||
func: The validator function.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
|
||||
JSON Schema (in validation mode).
|
||||
|
||||
```py
|
||||
from datetime import datetime
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, ValidationError, WrapValidator
|
||||
|
||||
def validate_timestamp(v, handler):
|
||||
if v == 'now':
|
||||
# we don't want to bother with further validation, just return the new value
|
||||
return datetime.now()
|
||||
try:
|
||||
return handler(v)
|
||||
except ValidationError:
|
||||
# validation failed, in this case we want to return a default value
|
||||
return datetime(2000, 1, 1)
|
||||
|
||||
MyTimestamp = Annotated[datetime, WrapValidator(validate_timestamp)]
|
||||
|
||||
class Model(BaseModel):
|
||||
a: MyTimestamp
|
||||
|
||||
print(Model(a='now').a)
|
||||
#> 2032-01-02 03:04:05.000006
|
||||
print(Model(a='invalid').a)
|
||||
#> 2000-01-01 00:00:00
|
||||
```
|
||||
"""
|
||||
|
||||
func: core_schema.NoInfoWrapValidatorFunction | core_schema.WithInfoWrapValidatorFunction
|
||||
json_schema_input_type: Any = PydanticUndefined
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
schema = handler(source_type)
|
||||
input_schema = (
|
||||
None
|
||||
if self.json_schema_input_type is PydanticUndefined
|
||||
else handler.generate_schema(self.json_schema_input_type)
|
||||
)
|
||||
metadata = _core_metadata.build_metadata_dict(js_input_core_schema=input_schema)
|
||||
|
||||
info_arg = _inspect_validator(self.func, 'wrap')
|
||||
if info_arg:
|
||||
func = cast(core_schema.WithInfoWrapValidatorFunction, self.func)
|
||||
return core_schema.with_info_wrap_validator_function(
|
||||
func,
|
||||
schema=schema,
|
||||
field_name=handler.field_name,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
func = cast(core_schema.NoInfoWrapValidatorFunction, self.func)
|
||||
return core_schema.no_info_wrap_validator_function(
|
||||
func,
|
||||
schema=schema,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(
|
||||
func=decorator.func,
|
||||
json_schema_input_type=decorator.info.json_schema_input_type,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _OnlyValueValidatorClsMethod(Protocol):
|
||||
def __call__(self, cls: Any, value: Any, /) -> Any: ...
|
||||
|
||||
class _V2ValidatorClsMethod(Protocol):
|
||||
def __call__(self, cls: Any, value: Any, info: _core_schema.ValidationInfo, /) -> Any: ...
|
||||
|
||||
class _OnlyValueWrapValidatorClsMethod(Protocol):
|
||||
def __call__(self, cls: Any, value: Any, handler: _core_schema.ValidatorFunctionWrapHandler, /) -> Any: ...
|
||||
|
||||
class _V2WrapValidatorClsMethod(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
cls: Any,
|
||||
value: Any,
|
||||
handler: _core_schema.ValidatorFunctionWrapHandler,
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> Any: ...
|
||||
|
||||
_V2Validator = Union[
|
||||
_V2ValidatorClsMethod,
|
||||
_core_schema.WithInfoValidatorFunction,
|
||||
_OnlyValueValidatorClsMethod,
|
||||
_core_schema.NoInfoValidatorFunction,
|
||||
]
|
||||
|
||||
_V2WrapValidator = Union[
|
||||
_V2WrapValidatorClsMethod,
|
||||
_core_schema.WithInfoWrapValidatorFunction,
|
||||
_OnlyValueWrapValidatorClsMethod,
|
||||
_core_schema.NoInfoWrapValidatorFunction,
|
||||
]
|
||||
|
||||
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
|
||||
|
||||
_V2BeforeAfterOrPlainValidatorType = TypeVar(
|
||||
'_V2BeforeAfterOrPlainValidatorType',
|
||||
bound=Union[_V2Validator, _PartialClsOrStaticMethod],
|
||||
)
|
||||
_V2WrapValidatorType = TypeVar('_V2WrapValidatorType', bound=Union[_V2WrapValidator, _PartialClsOrStaticMethod])
|
||||
|
||||
FieldValidatorModes: TypeAlias = Literal['before', 'after', 'wrap', 'plain']
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['wrap'],
|
||||
check_fields: bool | None = ...,
|
||||
json_schema_input_type: Any = ...,
|
||||
) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['before', 'plain'],
|
||||
check_fields: bool | None = ...,
|
||||
json_schema_input_type: Any = ...,
|
||||
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['after'] = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...
|
||||
|
||||
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: FieldValidatorModes = 'after',
|
||||
check_fields: bool | None = None,
|
||||
json_schema_input_type: Any = PydanticUndefined,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/validators/#field-validators
|
||||
|
||||
Decorate methods on the class indicating that they should be used to validate fields.
|
||||
|
||||
Example usage:
|
||||
```py
|
||||
from typing import Any
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
class Model(BaseModel):
|
||||
a: str
|
||||
|
||||
@field_validator('a')
|
||||
@classmethod
|
||||
def ensure_foobar(cls, v: Any):
|
||||
if 'foobar' not in v:
|
||||
raise ValueError('"foobar" not found in a')
|
||||
return v
|
||||
|
||||
print(repr(Model(a='this is foobar good')))
|
||||
#> Model(a='this is foobar good')
|
||||
|
||||
try:
|
||||
Model(a='snap')
|
||||
except ValidationError as exc_info:
|
||||
print(exc_info)
|
||||
'''
|
||||
1 validation error for Model
|
||||
a
|
||||
Value error, "foobar" not found in a [type=value_error, input_value='snap', input_type=str]
|
||||
'''
|
||||
```
|
||||
|
||||
For more in depth examples, see [Field Validators](../concepts/validators.md#field-validators).
|
||||
|
||||
Args:
|
||||
field: The first field the `field_validator` should be called on; this is separate
|
||||
from `fields` to ensure an error is raised if you don't pass at least one.
|
||||
*fields: Additional field(s) the `field_validator` should be called on.
|
||||
mode: Specifies whether to validate the fields before or after validation.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate
|
||||
the appropriate JSON Schema (in validation mode) and can only specified
|
||||
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
|
||||
|
||||
Returns:
|
||||
A decorator that can be used to decorate a function to be used as a field_validator.
|
||||
|
||||
Raises:
|
||||
PydanticUserError:
|
||||
- If `@field_validator` is used bare (with no fields).
|
||||
- If the args passed to `@field_validator` as fields are not strings.
|
||||
- If `@field_validator` applied to instance methods.
|
||||
"""
|
||||
if isinstance(field, FunctionType):
|
||||
raise PydanticUserError(
|
||||
'`@field_validator` should be used with fields and keyword arguments, not bare. '
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`",
|
||||
code='validator-no-fields',
|
||||
)
|
||||
|
||||
if mode not in ('before', 'plain', 'wrap') and json_schema_input_type is not PydanticUndefined:
|
||||
raise PydanticUserError(
|
||||
f"`json_schema_input_type` can't be used when mode is set to {mode!r}",
|
||||
code='validator-input-type',
|
||||
)
|
||||
|
||||
if json_schema_input_type is PydanticUndefined and mode == 'plain':
|
||||
json_schema_input_type = Any
|
||||
|
||||
fields = field, *fields
|
||||
if not all(isinstance(field, str) for field in fields):
|
||||
raise PydanticUserError(
|
||||
'`@field_validator` fields should be passed as separate string args. '
|
||||
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
|
||||
code='validator-invalid-fields',
|
||||
)
|
||||
|
||||
def dec(
|
||||
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any],
|
||||
) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
if _decorators.is_instance_method_from_sig(f):
|
||||
raise PydanticUserError(
|
||||
'`@field_validator` cannot be applied to instance methods', code='validator-instance-method'
|
||||
)
|
||||
|
||||
# auto apply the @classmethod decorator
|
||||
f = _decorators.ensure_classmethod_based_on_signature(f)
|
||||
|
||||
dec_info = _decorators.FieldValidatorDecoratorInfo(
|
||||
fields=fields, mode=mode, check_fields=check_fields, json_schema_input_type=json_schema_input_type
|
||||
)
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info)
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
_ModelType = TypeVar('_ModelType')
|
||||
_ModelTypeCo = TypeVar('_ModelTypeCo', covariant=True)
|
||||
|
||||
|
||||
class ModelWrapValidatorHandler(_core_schema.ValidatorFunctionWrapHandler, Protocol[_ModelTypeCo]):
|
||||
"""`@model_validator` decorated function handler argument type. This is used when `mode='wrap'`."""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
value: Any,
|
||||
outer_location: str | int | None = None,
|
||||
/,
|
||||
) -> _ModelTypeCo: # pragma: no cover
|
||||
...
|
||||
|
||||
|
||||
class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]):
|
||||
"""A `@model_validator` decorated function signature.
|
||||
This is used when `mode='wrap'` and the function does not have info argument.
|
||||
"""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
cls: type[_ModelType],
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
handler: ModelWrapValidatorHandler[_ModelType],
|
||||
/,
|
||||
) -> _ModelType: ...
|
||||
|
||||
|
||||
class ModelWrapValidator(Protocol[_ModelType]):
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='wrap'`."""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
cls: type[_ModelType],
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
handler: ModelWrapValidatorHandler[_ModelType],
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> _ModelType: ...
|
||||
|
||||
|
||||
class FreeModelBeforeValidatorWithoutInfo(Protocol):
|
||||
"""A `@model_validator` decorated function signature.
|
||||
This is used when `mode='before'` and the function does not have info argument.
|
||||
"""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
/,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class ModelBeforeValidatorWithoutInfo(Protocol):
|
||||
"""A `@model_validator` decorated function signature.
|
||||
This is used when `mode='before'` and the function does not have info argument.
|
||||
"""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
cls: Any,
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
/,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class FreeModelBeforeValidator(Protocol):
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class ModelBeforeValidator(Protocol):
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
cls: Any,
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType]
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='after'` and the function does not
|
||||
have info argument.
|
||||
"""
|
||||
|
||||
ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _ModelType]
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""
|
||||
|
||||
_AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]]
|
||||
_AnyModelBeforeValidator = Union[
|
||||
FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo
|
||||
]
|
||||
_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]]
|
||||
|
||||
|
||||
@overload
|
||||
def model_validator(
|
||||
*,
|
||||
mode: Literal['wrap'],
|
||||
) -> Callable[
|
||||
[_AnyModelWrapValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def model_validator(
|
||||
*,
|
||||
mode: Literal['before'],
|
||||
) -> Callable[
|
||||
[_AnyModelBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def model_validator(
|
||||
*,
|
||||
mode: Literal['after'],
|
||||
) -> Callable[
|
||||
[_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
|
||||
]: ...
|
||||
|
||||
|
||||
def model_validator(
|
||||
*,
|
||||
mode: Literal['wrap', 'before', 'after'],
|
||||
) -> Any:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/validators/#model-validators
|
||||
|
||||
Decorate model methods for validation purposes.
|
||||
|
||||
Example usage:
|
||||
```py
|
||||
from typing_extensions import Self
|
||||
|
||||
from pydantic import BaseModel, ValidationError, model_validator
|
||||
|
||||
class Square(BaseModel):
|
||||
width: float
|
||||
height: float
|
||||
|
||||
@model_validator(mode='after')
|
||||
def verify_square(self) -> Self:
|
||||
if self.width != self.height:
|
||||
raise ValueError('width and height do not match')
|
||||
return self
|
||||
|
||||
s = Square(width=1, height=1)
|
||||
print(repr(s))
|
||||
#> Square(width=1.0, height=1.0)
|
||||
|
||||
try:
|
||||
Square(width=1, height=2)
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for Square
|
||||
Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict]
|
||||
'''
|
||||
```
|
||||
|
||||
For more in depth examples, see [Model Validators](../concepts/validators.md#model-validators).
|
||||
|
||||
Args:
|
||||
mode: A required string literal that specifies the validation mode.
|
||||
It can be one of the following: 'wrap', 'before', or 'after'.
|
||||
|
||||
Returns:
|
||||
A decorator that can be used to decorate a function to be used as a model validator.
|
||||
"""
|
||||
|
||||
def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
# auto apply the @classmethod decorator
|
||||
f = _decorators.ensure_classmethod_based_on_signature(f)
|
||||
dec_info = _decorators.ModelValidatorDecoratorInfo(mode=mode)
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info)
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
AnyType = TypeVar('AnyType')
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# If we add configurable attributes to IsInstance, we'd probably need to stop hiding it from type checkers like this
|
||||
InstanceOf = Annotated[AnyType, ...] # `IsInstance[Sequence]` will be recognized by type checkers as `Sequence`
|
||||
|
||||
else:
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class InstanceOf:
|
||||
'''Generic type for annotating a type that is an instance of a given class.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from pydantic import BaseModel, InstanceOf
|
||||
|
||||
class Foo:
|
||||
...
|
||||
|
||||
class Bar(BaseModel):
|
||||
foo: InstanceOf[Foo]
|
||||
|
||||
Bar(foo=Foo())
|
||||
try:
|
||||
Bar(foo=42)
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
"""
|
||||
[
|
||||
│ {
|
||||
│ │ 'type': 'is_instance_of',
|
||||
│ │ 'loc': ('foo',),
|
||||
│ │ 'msg': 'Input should be an instance of Foo',
|
||||
│ │ 'input': 42,
|
||||
│ │ 'ctx': {'class': 'Foo'},
|
||||
│ │ 'url': 'https://errors.pydantic.dev/0.38.0/v/is_instance_of'
|
||||
│ }
|
||||
]
|
||||
"""
|
||||
```
|
||||
'''
|
||||
|
||||
@classmethod
|
||||
def __class_getitem__(cls, item: AnyType) -> AnyType:
|
||||
return Annotated[item, cls()]
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
from pydantic import PydanticSchemaGenerationError
|
||||
|
||||
# use the generic _origin_ as the second argument to isinstance when appropriate
|
||||
instance_of_schema = core_schema.is_instance_schema(_generics.get_origin(source) or source)
|
||||
|
||||
try:
|
||||
# Try to generate the "standard" schema, which will be used when loading from JSON
|
||||
original_schema = handler(source)
|
||||
except PydanticSchemaGenerationError:
|
||||
# If that fails, just produce a schema that can validate from python
|
||||
return instance_of_schema
|
||||
else:
|
||||
# Use the "original" approach to serialization
|
||||
instance_of_schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
|
||||
function=lambda v, h: h(v), schema=original_schema
|
||||
)
|
||||
return core_schema.json_or_python_schema(python_schema=instance_of_schema, json_schema=original_schema)
|
||||
|
||||
__hash__ = object.__hash__
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
SkipValidation = Annotated[AnyType, ...] # SkipValidation[list[str]] will be treated by type checkers as list[str]
|
||||
else:
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class SkipValidation:
|
||||
"""If this is applied as an annotation (e.g., via `x: Annotated[int, SkipValidation]`), validation will be
|
||||
skipped. You can also use `SkipValidation[int]` as a shorthand for `Annotated[int, SkipValidation]`.
|
||||
|
||||
This can be useful if you want to use a type annotation for documentation/IDE/type-checking purposes,
|
||||
and know that it is safe to skip validation for one or more of the fields.
|
||||
|
||||
Because this converts the validation schema to `any_schema`, subsequent annotation-applied transformations
|
||||
may not have the expected effects. Therefore, when used, this annotation should generally be the final
|
||||
annotation applied to a type.
|
||||
"""
|
||||
|
||||
def __class_getitem__(cls, item: Any) -> Any:
|
||||
return Annotated[item, SkipValidation()]
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
original_schema = handler(source)
|
||||
metadata = _core_metadata.build_metadata_dict(js_annotation_functions=[lambda _c, h: h(original_schema)])
|
||||
return core_schema.any_schema(
|
||||
metadata=metadata,
|
||||
serialization=core_schema.wrap_serializer_function_ser_schema(
|
||||
function=lambda v, h: h(v), schema=original_schema
|
||||
),
|
||||
)
|
||||
|
||||
__hash__ = object.__hash__
|
||||
5
venv/lib/python3.11/site-packages/pydantic/generics.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/generics.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""The `generics` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
5
venv/lib/python3.11/site-packages/pydantic/json.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/json.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""The `json` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
2574
venv/lib/python3.11/site-packages/pydantic/json_schema.py
Normal file
2574
venv/lib/python3.11/site-packages/pydantic/json_schema.py
Normal file
File diff suppressed because it is too large
Load Diff
1610
venv/lib/python3.11/site-packages/pydantic/main.py
Normal file
1610
venv/lib/python3.11/site-packages/pydantic/main.py
Normal file
File diff suppressed because it is too large
Load Diff
1396
venv/lib/python3.11/site-packages/pydantic/mypy.py
Normal file
1396
venv/lib/python3.11/site-packages/pydantic/mypy.py
Normal file
File diff suppressed because it is too large
Load Diff
778
venv/lib/python3.11/site-packages/pydantic/networks.py
Normal file
778
venv/lib/python3.11/site-packages/pydantic/networks.py
Normal file
@@ -0,0 +1,778 @@
|
||||
"""The networks module contains types for common network-related fields."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses as _dataclasses
|
||||
import re
|
||||
from importlib.metadata import version
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic_core import MultiHostUrl, PydanticCustomError, Url, core_schema
|
||||
from typing_extensions import Annotated, Self, TypeAlias
|
||||
|
||||
from ._internal import _fields, _repr, _schema_generation_shared
|
||||
from ._migration import getattr_migration
|
||||
from .annotated_handlers import GetCoreSchemaHandler
|
||||
from .json_schema import JsonSchemaValue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import email_validator
|
||||
|
||||
NetworkType: TypeAlias = 'str | bytes | int | tuple[str | bytes | int, str | int]'
|
||||
|
||||
else:
|
||||
email_validator = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
'AnyUrl',
|
||||
'AnyHttpUrl',
|
||||
'FileUrl',
|
||||
'FtpUrl',
|
||||
'HttpUrl',
|
||||
'WebsocketUrl',
|
||||
'AnyWebsocketUrl',
|
||||
'UrlConstraints',
|
||||
'EmailStr',
|
||||
'NameEmail',
|
||||
'IPvAnyAddress',
|
||||
'IPvAnyInterface',
|
||||
'IPvAnyNetwork',
|
||||
'PostgresDsn',
|
||||
'CockroachDsn',
|
||||
'AmqpDsn',
|
||||
'RedisDsn',
|
||||
'MongoDsn',
|
||||
'KafkaDsn',
|
||||
'NatsDsn',
|
||||
'validate_email',
|
||||
'MySQLDsn',
|
||||
'MariaDBDsn',
|
||||
'ClickHouseDsn',
|
||||
'SnowflakeDsn',
|
||||
]
|
||||
|
||||
|
||||
@_dataclasses.dataclass
|
||||
class UrlConstraints(_fields.PydanticMetadata):
|
||||
"""Url constraints.
|
||||
|
||||
Attributes:
|
||||
max_length: The maximum length of the url. Defaults to `None`.
|
||||
allowed_schemes: The allowed schemes. Defaults to `None`.
|
||||
host_required: Whether the host is required. Defaults to `None`.
|
||||
default_host: The default host. Defaults to `None`.
|
||||
default_port: The default port. Defaults to `None`.
|
||||
default_path: The default path. Defaults to `None`.
|
||||
"""
|
||||
|
||||
max_length: int | None = None
|
||||
allowed_schemes: list[str] | None = None
|
||||
host_required: bool | None = None
|
||||
default_host: str | None = None
|
||||
default_port: int | None = None
|
||||
default_path: str | None = None
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.max_length,
|
||||
tuple(self.allowed_schemes) if self.allowed_schemes is not None else None,
|
||||
self.host_required,
|
||||
self.default_host,
|
||||
self.default_port,
|
||||
self.default_path,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
AnyUrl = Url
|
||||
"""Base type for all URLs.
|
||||
|
||||
* Any scheme allowed
|
||||
* Top-level domain (TLD) not required
|
||||
* Host required
|
||||
|
||||
Assuming an input URL of `http://samuel:pass@example.com:8000/the/path/?query=here#fragment=is;this=bit`,
|
||||
the types export the following properties:
|
||||
|
||||
- `scheme`: the URL scheme (`http`), always set.
|
||||
- `host`: the URL host (`example.com`), always set.
|
||||
- `username`: optional username if included (`samuel`).
|
||||
- `password`: optional password if included (`pass`).
|
||||
- `port`: optional port (`8000`).
|
||||
- `path`: optional path (`/the/path/`).
|
||||
- `query`: optional URL query (for example, `GET` arguments or "search string", such as `query=here`).
|
||||
- `fragment`: optional fragment (`fragment=is;this=bit`).
|
||||
"""
|
||||
AnyHttpUrl = Annotated[Url, UrlConstraints(allowed_schemes=['http', 'https'])]
|
||||
"""A type that will accept any http or https URL.
|
||||
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
HttpUrl = Annotated[Url, UrlConstraints(max_length=2083, allowed_schemes=['http', 'https'])]
|
||||
"""A type that will accept any http or https URL.
|
||||
|
||||
* TLD not required
|
||||
* Host required
|
||||
* Max length 2083
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, HttpUrl, ValidationError
|
||||
|
||||
class MyModel(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
m = MyModel(url='http://www.example.com') # (1)!
|
||||
print(m.url)
|
||||
#> http://www.example.com/
|
||||
|
||||
try:
|
||||
MyModel(url='ftp://invalid.url')
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for MyModel
|
||||
url
|
||||
URL scheme should be 'http' or 'https' [type=url_scheme, input_value='ftp://invalid.url', input_type=str]
|
||||
'''
|
||||
|
||||
try:
|
||||
MyModel(url='not a url')
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for MyModel
|
||||
url
|
||||
Input should be a valid URL, relative URL without a base [type=url_parsing, input_value='not a url', input_type=str]
|
||||
'''
|
||||
```
|
||||
|
||||
1. Note: mypy would prefer `m = MyModel(url=HttpUrl('http://www.example.com'))`, but Pydantic will convert the string to an HttpUrl instance anyway.
|
||||
|
||||
"International domains" (e.g. a URL where the host or TLD includes non-ascii characters) will be encoded via
|
||||
[punycode](https://en.wikipedia.org/wiki/Punycode) (see
|
||||
[this article](https://www.xudongz.com/blog/2017/idn-phishing/) for a good description of why this is important):
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
class MyModel(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
m1 = MyModel(url='http://puny£code.com')
|
||||
print(m1.url)
|
||||
#> http://xn--punycode-eja.com/
|
||||
m2 = MyModel(url='https://www.аррӏе.com/')
|
||||
print(m2.url)
|
||||
#> https://www.xn--80ak6aa92e.com/
|
||||
m3 = MyModel(url='https://www.example.珠宝/')
|
||||
print(m3.url)
|
||||
#> https://www.example.xn--pbt977c/
|
||||
```
|
||||
|
||||
|
||||
!!! warning "Underscores in Hostnames"
|
||||
In Pydantic, underscores are allowed in all parts of a domain except the TLD.
|
||||
Technically this might be wrong - in theory the hostname cannot have underscores, but subdomains can.
|
||||
|
||||
To explain this; consider the following two cases:
|
||||
|
||||
- `exam_ple.co.uk`: the hostname is `exam_ple`, which should not be allowed since it contains an underscore.
|
||||
- `foo_bar.example.com` the hostname is `example`, which should be allowed since the underscore is in the subdomain.
|
||||
|
||||
Without having an exhaustive list of TLDs, it would be impossible to differentiate between these two. Therefore
|
||||
underscores are allowed, but you can always do further validation in a validator if desired.
|
||||
|
||||
Also, Chrome, Firefox, and Safari all currently accept `http://exam_ple.com` as a URL, so we're in good
|
||||
(or at least big) company.
|
||||
"""
|
||||
AnyWebsocketUrl = Annotated[Url, UrlConstraints(allowed_schemes=['ws', 'wss'])]
|
||||
"""A type that will accept any ws or wss URL.
|
||||
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
WebsocketUrl = Annotated[Url, UrlConstraints(max_length=2083, allowed_schemes=['ws', 'wss'])]
|
||||
"""A type that will accept any ws or wss URL.
|
||||
|
||||
* TLD not required
|
||||
* Host required
|
||||
* Max length 2083
|
||||
"""
|
||||
FileUrl = Annotated[Url, UrlConstraints(allowed_schemes=['file'])]
|
||||
"""A type that will accept any file URL.
|
||||
|
||||
* Host not required
|
||||
"""
|
||||
FtpUrl = Annotated[Url, UrlConstraints(allowed_schemes=['ftp'])]
|
||||
"""A type that will accept ftp URL.
|
||||
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
PostgresDsn = Annotated[
|
||||
MultiHostUrl,
|
||||
UrlConstraints(
|
||||
host_required=True,
|
||||
allowed_schemes=[
|
||||
'postgres',
|
||||
'postgresql',
|
||||
'postgresql+asyncpg',
|
||||
'postgresql+pg8000',
|
||||
'postgresql+psycopg',
|
||||
'postgresql+psycopg2',
|
||||
'postgresql+psycopg2cffi',
|
||||
'postgresql+py-postgresql',
|
||||
'postgresql+pygresql',
|
||||
],
|
||||
),
|
||||
]
|
||||
"""A type that will accept any Postgres DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
* Supports multiple hosts
|
||||
|
||||
If further validation is required, these properties can be used by validators to enforce specific behaviour:
|
||||
|
||||
```py
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
HttpUrl,
|
||||
PostgresDsn,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
class MyModel(BaseModel):
|
||||
url: HttpUrl
|
||||
|
||||
m = MyModel(url='http://www.example.com')
|
||||
|
||||
# the repr() method for a url will display all properties of the url
|
||||
print(repr(m.url))
|
||||
#> Url('http://www.example.com/')
|
||||
print(m.url.scheme)
|
||||
#> http
|
||||
print(m.url.host)
|
||||
#> www.example.com
|
||||
print(m.url.port)
|
||||
#> 80
|
||||
|
||||
class MyDatabaseModel(BaseModel):
|
||||
db: PostgresDsn
|
||||
|
||||
@field_validator('db')
|
||||
def check_db_name(cls, v):
|
||||
assert v.path and len(v.path) > 1, 'database must be provided'
|
||||
return v
|
||||
|
||||
m = MyDatabaseModel(db='postgres://user:pass@localhost:5432/foobar')
|
||||
print(m.db)
|
||||
#> postgres://user:pass@localhost:5432/foobar
|
||||
|
||||
try:
|
||||
MyDatabaseModel(db='postgres://user:pass@localhost:5432')
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for MyDatabaseModel
|
||||
db
|
||||
Assertion failed, database must be provided
|
||||
assert (None)
|
||||
+ where None = MultiHostUrl('postgres://user:pass@localhost:5432').path [type=assertion_error, input_value='postgres://user:pass@localhost:5432', input_type=str]
|
||||
'''
|
||||
```
|
||||
"""
|
||||
|
||||
CockroachDsn = Annotated[
|
||||
Url,
|
||||
UrlConstraints(
|
||||
host_required=True,
|
||||
allowed_schemes=[
|
||||
'cockroachdb',
|
||||
'cockroachdb+psycopg2',
|
||||
'cockroachdb+asyncpg',
|
||||
],
|
||||
),
|
||||
]
|
||||
"""A type that will accept any Cockroach DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
AmqpDsn = Annotated[Url, UrlConstraints(allowed_schemes=['amqp', 'amqps'])]
|
||||
"""A type that will accept any AMQP DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
RedisDsn = Annotated[
|
||||
Url,
|
||||
UrlConstraints(allowed_schemes=['redis', 'rediss'], default_host='localhost', default_port=6379, default_path='/0'),
|
||||
]
|
||||
"""A type that will accept any Redis DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required (e.g., `rediss://:pass@localhost`)
|
||||
"""
|
||||
MongoDsn = Annotated[MultiHostUrl, UrlConstraints(allowed_schemes=['mongodb', 'mongodb+srv'], default_port=27017)]
|
||||
"""A type that will accept any MongoDB DSN.
|
||||
|
||||
* User info not required
|
||||
* Database name not required
|
||||
* Port not required
|
||||
* User info may be passed without user part (e.g., `mongodb://mongodb0.example.com:27017`).
|
||||
"""
|
||||
KafkaDsn = Annotated[Url, UrlConstraints(allowed_schemes=['kafka'], default_host='localhost', default_port=9092)]
|
||||
"""A type that will accept any Kafka DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
NatsDsn = Annotated[
|
||||
MultiHostUrl,
|
||||
UrlConstraints(allowed_schemes=['nats', 'tls', 'ws', 'wss'], default_host='localhost', default_port=4222),
|
||||
]
|
||||
"""A type that will accept any NATS DSN.
|
||||
|
||||
NATS is a connective technology built for the ever increasingly hyper-connected world.
|
||||
It is a single technology that enables applications to securely communicate across
|
||||
any combination of cloud vendors, on-premise, edge, web and mobile, and devices.
|
||||
More: https://nats.io
|
||||
"""
|
||||
MySQLDsn = Annotated[
|
||||
Url,
|
||||
UrlConstraints(
|
||||
allowed_schemes=[
|
||||
'mysql',
|
||||
'mysql+mysqlconnector',
|
||||
'mysql+aiomysql',
|
||||
'mysql+asyncmy',
|
||||
'mysql+mysqldb',
|
||||
'mysql+pymysql',
|
||||
'mysql+cymysql',
|
||||
'mysql+pyodbc',
|
||||
],
|
||||
default_port=3306,
|
||||
),
|
||||
]
|
||||
"""A type that will accept any MySQL DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
MariaDBDsn = Annotated[
|
||||
Url,
|
||||
UrlConstraints(
|
||||
allowed_schemes=['mariadb', 'mariadb+mariadbconnector', 'mariadb+pymysql'],
|
||||
default_port=3306,
|
||||
),
|
||||
]
|
||||
"""A type that will accept any MariaDB DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
ClickHouseDsn = Annotated[
|
||||
Url,
|
||||
UrlConstraints(
|
||||
allowed_schemes=['clickhouse+native', 'clickhouse+asynch'],
|
||||
default_host='localhost',
|
||||
default_port=9000,
|
||||
),
|
||||
]
|
||||
"""A type that will accept any ClickHouse DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
SnowflakeDsn = Annotated[
|
||||
Url,
|
||||
UrlConstraints(
|
||||
allowed_schemes=['snowflake'],
|
||||
host_required=True,
|
||||
),
|
||||
]
|
||||
"""A type that will accept any Snowflake DSN.
|
||||
|
||||
* User info required
|
||||
* TLD not required
|
||||
* Host required
|
||||
"""
|
||||
|
||||
|
||||
def import_email_validator() -> None:
|
||||
global email_validator
|
||||
try:
|
||||
import email_validator
|
||||
except ImportError as e:
|
||||
raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e
|
||||
if not version('email-validator').partition('.')[0] == '2':
|
||||
raise ImportError('email-validator version >= 2.0 required, run pip install -U email-validator')
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
EmailStr = Annotated[str, ...]
|
||||
else:
|
||||
|
||||
class EmailStr:
|
||||
"""
|
||||
Info:
|
||||
To use this type, you need to install the optional
|
||||
[`email-validator`](https://github.com/JoshData/python-email-validator) package:
|
||||
|
||||
```bash
|
||||
pip install email-validator
|
||||
```
|
||||
|
||||
Validate email addresses.
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
class Model(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
print(Model(email='contact@mail.com'))
|
||||
#> email='contact@mail.com'
|
||||
```
|
||||
""" # noqa: D212
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source: type[Any],
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.CoreSchema:
|
||||
import_email_validator()
|
||||
return core_schema.no_info_after_validator_function(cls._validate, core_schema.str_schema())
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
field_schema = handler(core_schema)
|
||||
field_schema.update(type='string', format='email')
|
||||
return field_schema
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, input_value: str, /) -> str:
|
||||
return validate_email(input_value)[1]
|
||||
|
||||
|
||||
class NameEmail(_repr.Representation):
|
||||
"""
|
||||
Info:
|
||||
To use this type, you need to install the optional
|
||||
[`email-validator`](https://github.com/JoshData/python-email-validator) package:
|
||||
|
||||
```bash
|
||||
pip install email-validator
|
||||
```
|
||||
|
||||
Validate a name and email address combination, as specified by
|
||||
[RFC 5322](https://datatracker.ietf.org/doc/html/rfc5322#section-3.4).
|
||||
|
||||
The `NameEmail` has two properties: `name` and `email`.
|
||||
In case the `name` is not provided, it's inferred from the email address.
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, NameEmail
|
||||
|
||||
class User(BaseModel):
|
||||
email: NameEmail
|
||||
|
||||
user = User(email='Fred Bloggs <fred.bloggs@example.com>')
|
||||
print(user.email)
|
||||
#> Fred Bloggs <fred.bloggs@example.com>
|
||||
print(user.email.name)
|
||||
#> Fred Bloggs
|
||||
|
||||
user = User(email='fred.bloggs@example.com')
|
||||
print(user.email)
|
||||
#> fred.bloggs <fred.bloggs@example.com>
|
||||
print(user.email.name)
|
||||
#> fred.bloggs
|
||||
```
|
||||
""" # noqa: D212
|
||||
|
||||
__slots__ = 'name', 'email'
|
||||
|
||||
def __init__(self, name: str, email: str):
|
||||
self.name = name
|
||||
self.email = email
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
field_schema = handler(core_schema)
|
||||
field_schema.update(type='string', format='name-email')
|
||||
return field_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source: type[Any],
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.CoreSchema:
|
||||
import_email_validator()
|
||||
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls._validate,
|
||||
core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.str_schema(),
|
||||
python_schema=core_schema.union_schema(
|
||||
[core_schema.is_instance_schema(cls), core_schema.str_schema()],
|
||||
custom_error_type='name_email_type',
|
||||
custom_error_message='Input is not a valid NameEmail',
|
||||
),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, input_value: Self | str, /) -> Self:
|
||||
if isinstance(input_value, str):
|
||||
name, email = validate_email(input_value)
|
||||
return cls(name, email)
|
||||
else:
|
||||
return input_value
|
||||
|
||||
def __str__(self) -> str:
|
||||
if '@' in self.name:
|
||||
return f'"{self.name}" <{self.email}>'
|
||||
|
||||
return f'{self.name} <{self.email}>'
|
||||
|
||||
|
||||
IPvAnyAddressType: TypeAlias = 'IPv4Address | IPv6Address'
|
||||
IPvAnyInterfaceType: TypeAlias = 'IPv4Interface | IPv6Interface'
|
||||
IPvAnyNetworkType: TypeAlias = 'IPv4Network | IPv6Network'
|
||||
|
||||
if TYPE_CHECKING:
|
||||
IPvAnyAddress = IPvAnyAddressType
|
||||
IPvAnyInterface = IPvAnyInterfaceType
|
||||
IPvAnyNetwork = IPvAnyNetworkType
|
||||
else:
|
||||
|
||||
class IPvAnyAddress:
|
||||
"""Validate an IPv4 or IPv6 address.
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import IPvAnyAddress
|
||||
|
||||
class IpModel(BaseModel):
|
||||
ip: IPvAnyAddress
|
||||
|
||||
print(IpModel(ip='127.0.0.1'))
|
||||
#> ip=IPv4Address('127.0.0.1')
|
||||
|
||||
try:
|
||||
IpModel(ip='http://www.example.com')
|
||||
except ValueError as e:
|
||||
print(e.errors())
|
||||
'''
|
||||
[
|
||||
{
|
||||
'type': 'ip_any_address',
|
||||
'loc': ('ip',),
|
||||
'msg': 'value is not a valid IPv4 or IPv6 address',
|
||||
'input': 'http://www.example.com',
|
||||
}
|
||||
]
|
||||
'''
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, value: Any) -> IPvAnyAddressType:
|
||||
"""Validate an IPv4 or IPv6 address."""
|
||||
try:
|
||||
return IPv4Address(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Address(value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_any_address', 'value is not a valid IPv4 or IPv6 address')
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
field_schema = {}
|
||||
field_schema.update(type='string', format='ipvanyaddress')
|
||||
return field_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source: type[Any],
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.CoreSchema:
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
cls._validate, serialization=core_schema.to_string_ser_schema()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, input_value: Any, /) -> IPvAnyAddressType:
|
||||
return cls(input_value) # type: ignore[return-value]
|
||||
|
||||
class IPvAnyInterface:
|
||||
"""Validate an IPv4 or IPv6 interface."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, value: NetworkType) -> IPvAnyInterfaceType:
|
||||
"""Validate an IPv4 or IPv6 interface."""
|
||||
try:
|
||||
return IPv4Interface(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Interface(value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_any_interface', 'value is not a valid IPv4 or IPv6 interface')
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
field_schema = {}
|
||||
field_schema.update(type='string', format='ipvanyinterface')
|
||||
return field_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source: type[Any],
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.CoreSchema:
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
cls._validate, serialization=core_schema.to_string_ser_schema()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, input_value: NetworkType, /) -> IPvAnyInterfaceType:
|
||||
return cls(input_value) # type: ignore[return-value]
|
||||
|
||||
class IPvAnyNetwork:
|
||||
"""Validate an IPv4 or IPv6 network."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, value: NetworkType) -> IPvAnyNetworkType:
|
||||
"""Validate an IPv4 or IPv6 network."""
|
||||
# Assume IP Network is defined with a default value for `strict` argument.
|
||||
# Define your own class if you want to specify network address check strictness.
|
||||
try:
|
||||
return IPv4Network(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Network(value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_any_network', 'value is not a valid IPv4 or IPv6 network')
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
field_schema = {}
|
||||
field_schema.update(type='string', format='ipvanynetwork')
|
||||
return field_schema
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source: type[Any],
|
||||
_handler: GetCoreSchemaHandler,
|
||||
) -> core_schema.CoreSchema:
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
cls._validate, serialization=core_schema.to_string_ser_schema()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, input_value: NetworkType, /) -> IPvAnyNetworkType:
|
||||
return cls(input_value) # type: ignore[return-value]
|
||||
|
||||
|
||||
def _build_pretty_email_regex() -> re.Pattern[str]:
|
||||
name_chars = r'[\w!#$%&\'*+\-/=?^_`{|}~]'
|
||||
unquoted_name_group = rf'((?:{name_chars}+\s+)*{name_chars}+)'
|
||||
quoted_name_group = r'"((?:[^"]|\")+)"'
|
||||
email_group = r'<\s*(.+)\s*>'
|
||||
return re.compile(rf'\s*(?:{unquoted_name_group}|{quoted_name_group})?\s*{email_group}\s*')
|
||||
|
||||
|
||||
pretty_email_regex = _build_pretty_email_regex()
|
||||
|
||||
MAX_EMAIL_LENGTH = 2048
|
||||
"""Maximum length for an email.
|
||||
A somewhat arbitrary but very generous number compared to what is allowed by most implementations.
|
||||
"""
|
||||
|
||||
|
||||
def validate_email(value: str) -> tuple[str, str]:
|
||||
"""Email address validation using [email-validator](https://pypi.org/project/email-validator/).
|
||||
|
||||
Note:
|
||||
Note that:
|
||||
|
||||
* Raw IP address (literal) domain parts are not allowed.
|
||||
* `"John Doe <local_part@domain.com>"` style "pretty" email addresses are processed.
|
||||
* Spaces are striped from the beginning and end of addresses, but no error is raised.
|
||||
"""
|
||||
if email_validator is None:
|
||||
import_email_validator()
|
||||
|
||||
if len(value) > MAX_EMAIL_LENGTH:
|
||||
raise PydanticCustomError(
|
||||
'value_error',
|
||||
'value is not a valid email address: {reason}',
|
||||
{'reason': f'Length must not exceed {MAX_EMAIL_LENGTH} characters'},
|
||||
)
|
||||
|
||||
m = pretty_email_regex.fullmatch(value)
|
||||
name: str | None = None
|
||||
if m:
|
||||
unquoted_name, quoted_name, value = m.groups()
|
||||
name = unquoted_name or quoted_name
|
||||
|
||||
email = value.strip()
|
||||
|
||||
try:
|
||||
parts = email_validator.validate_email(email, check_deliverability=False)
|
||||
except email_validator.EmailNotValidError as e:
|
||||
raise PydanticCustomError(
|
||||
'value_error', 'value is not a valid email address: {reason}', {'reason': str(e.args[0])}
|
||||
) from e
|
||||
|
||||
email = parts.normalized
|
||||
assert email is not None
|
||||
name = name or parts.local_part
|
||||
return name, email
|
||||
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
5
venv/lib/python3.11/site-packages/pydantic/parse.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/parse.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""The `parse` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
171
venv/lib/python3.11/site-packages/pydantic/plugin/__init__.py
Normal file
171
venv/lib/python3.11/site-packages/pydantic/plugin/__init__.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/plugins#build-a-plugin
|
||||
|
||||
Plugin interface for Pydantic plugins, and related types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, NamedTuple
|
||||
|
||||
from pydantic_core import CoreConfig, CoreSchema, ValidationError
|
||||
from typing_extensions import Literal, Protocol, TypeAlias
|
||||
|
||||
__all__ = (
|
||||
'PydanticPluginProtocol',
|
||||
'BaseValidateHandlerProtocol',
|
||||
'ValidatePythonHandlerProtocol',
|
||||
'ValidateJsonHandlerProtocol',
|
||||
'ValidateStringsHandlerProtocol',
|
||||
'NewSchemaReturns',
|
||||
'SchemaTypePath',
|
||||
'SchemaKind',
|
||||
)
|
||||
|
||||
NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]'
|
||||
|
||||
|
||||
class SchemaTypePath(NamedTuple):
|
||||
"""Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""
|
||||
|
||||
module: str
|
||||
name: str
|
||||
|
||||
|
||||
SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call']
|
||||
|
||||
|
||||
class PydanticPluginProtocol(Protocol):
|
||||
"""Protocol defining the interface for Pydantic plugins."""
|
||||
|
||||
def new_schema_validator(
|
||||
self,
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_path: SchemaTypePath,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None,
|
||||
plugin_settings: dict[str, object],
|
||||
) -> tuple[
|
||||
ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None
|
||||
]:
|
||||
"""This method is called for each plugin every time a new [`SchemaValidator`][pydantic_core.SchemaValidator]
|
||||
is created.
|
||||
|
||||
It should return an event handler for each of the three validation methods, or `None` if the plugin does not
|
||||
implement that method.
|
||||
|
||||
Args:
|
||||
schema: The schema to validate against.
|
||||
schema_type: The original type which the schema was created from, e.g. the model class.
|
||||
schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called.
|
||||
schema_kind: The kind of schema to validate against.
|
||||
config: The config to use for validation.
|
||||
plugin_settings: Any plugin settings.
|
||||
|
||||
Returns:
|
||||
A tuple of optional event handlers for each of the three validation methods -
|
||||
`validate_python`, `validate_json`, `validate_strings`.
|
||||
"""
|
||||
raise NotImplementedError('Pydantic plugins should implement `new_schema_validator`.')
|
||||
|
||||
|
||||
class BaseValidateHandlerProtocol(Protocol):
|
||||
"""Base class for plugin callbacks protocols.
|
||||
|
||||
You shouldn't implement this protocol directly, instead use one of the subclasses with adds the correctly
|
||||
typed `on_error` method.
|
||||
"""
|
||||
|
||||
on_enter: Callable[..., None]
|
||||
"""`on_enter` is changed to be more specific on all subclasses"""
|
||||
|
||||
def on_success(self, result: Any) -> None:
|
||||
"""Callback to be notified of successful validation.
|
||||
|
||||
Args:
|
||||
result: The result of the validation.
|
||||
"""
|
||||
return
|
||||
|
||||
def on_error(self, error: ValidationError) -> None:
|
||||
"""Callback to be notified of validation errors.
|
||||
|
||||
Args:
|
||||
error: The validation error.
|
||||
"""
|
||||
return
|
||||
|
||||
def on_exception(self, exception: Exception) -> None:
|
||||
"""Callback to be notified of validation exceptions.
|
||||
|
||||
Args:
|
||||
exception: The exception raised during validation.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
"""Event handler for `SchemaValidator.validate_python`."""
|
||||
|
||||
def on_enter(
|
||||
self,
|
||||
input: Any,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
from_attributes: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
self_instance: Any | None = None,
|
||||
) -> None:
|
||||
"""Callback to be notified of validation start, and create an instance of the event handler.
|
||||
|
||||
Args:
|
||||
input: The input to be validated.
|
||||
strict: Whether to validate the object in strict mode.
|
||||
from_attributes: Whether to validate objects as inputs by extracting attributes.
|
||||
context: The context to use for validation, this is passed to functional validators.
|
||||
self_instance: An instance of a model to set attributes on from validation, this is used when running
|
||||
validation from the `__init__` method of a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
"""Event handler for `SchemaValidator.validate_json`."""
|
||||
|
||||
def on_enter(
|
||||
self,
|
||||
input: str | bytes | bytearray,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
self_instance: Any | None = None,
|
||||
) -> None:
|
||||
"""Callback to be notified of validation start, and create an instance of the event handler.
|
||||
|
||||
Args:
|
||||
input: The JSON data to be validated.
|
||||
strict: Whether to validate the object in strict mode.
|
||||
context: The context to use for validation, this is passed to functional validators.
|
||||
self_instance: An instance of a model to set attributes on from validation, this is used when running
|
||||
validation from the `__init__` method of a model.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
StringInput: TypeAlias = 'dict[str, StringInput]'
|
||||
|
||||
|
||||
class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
"""Event handler for `SchemaValidator.validate_strings`."""
|
||||
|
||||
def on_enter(
|
||||
self, input: StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Callback to be notified of validation start, and create an instance of the event handler.
|
||||
|
||||
Args:
|
||||
input: The string data to be validated.
|
||||
strict: Whether to validate the object in strict mode.
|
||||
context: The context to use for validation, this is passed to functional validators.
|
||||
"""
|
||||
pass
|
||||
56
venv/lib/python3.11/site-packages/pydantic/plugin/_loader.py
Normal file
56
venv/lib/python3.11/site-packages/pydantic/plugin/_loader.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata as importlib_metadata
|
||||
import os
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Final, Iterable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PydanticPluginProtocol
|
||||
|
||||
|
||||
PYDANTIC_ENTRY_POINT_GROUP: Final[str] = 'pydantic'
|
||||
|
||||
# cache of plugins
|
||||
_plugins: dict[str, PydanticPluginProtocol] | None = None
|
||||
# return no plugins while loading plugins to avoid recursion and errors while import plugins
|
||||
# this means that if plugins use pydantic
|
||||
_loading_plugins: bool = False
|
||||
|
||||
|
||||
def get_plugins() -> Iterable[PydanticPluginProtocol]:
|
||||
"""Load plugins for Pydantic.
|
||||
|
||||
Inspired by: https://github.com/pytest-dev/pluggy/blob/1.3.0/src/pluggy/_manager.py#L376-L402
|
||||
"""
|
||||
disabled_plugins = os.getenv('PYDANTIC_DISABLE_PLUGINS')
|
||||
global _plugins, _loading_plugins
|
||||
if _loading_plugins:
|
||||
# this happens when plugins themselves use pydantic, we return no plugins
|
||||
return ()
|
||||
elif disabled_plugins in ('__all__', '1', 'true'):
|
||||
return ()
|
||||
elif _plugins is None:
|
||||
_plugins = {}
|
||||
# set _loading_plugins so any plugins that use pydantic don't themselves use plugins
|
||||
_loading_plugins = True
|
||||
try:
|
||||
for dist in importlib_metadata.distributions():
|
||||
for entry_point in dist.entry_points:
|
||||
if entry_point.group != PYDANTIC_ENTRY_POINT_GROUP:
|
||||
continue
|
||||
if entry_point.value in _plugins:
|
||||
continue
|
||||
if disabled_plugins is not None and entry_point.name in disabled_plugins.split(','):
|
||||
continue
|
||||
try:
|
||||
_plugins[entry_point.value] = entry_point.load()
|
||||
except (ImportError, AttributeError) as e:
|
||||
warnings.warn(
|
||||
f'{e.__class__.__name__} while loading the `{entry_point.name}` Pydantic plugin, '
|
||||
f'this plugin will not be installed.\n\n{e!r}'
|
||||
)
|
||||
finally:
|
||||
_loading_plugins = False
|
||||
|
||||
return _plugins.values()
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Pluggable schema validator for pydantic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar
|
||||
|
||||
from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
|
||||
from typing_extensions import Literal, ParamSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
|
||||
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
Event = Literal['on_validate_python', 'on_validate_json', 'on_validate_strings']
|
||||
events: list[Event] = list(Event.__args__) # type: ignore
|
||||
|
||||
|
||||
def create_schema_validator(
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_module: str,
|
||||
schema_type_name: str,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None = None,
|
||||
plugin_settings: dict[str, Any] | None = None,
|
||||
) -> SchemaValidator | PluggableSchemaValidator:
|
||||
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
|
||||
|
||||
Returns:
|
||||
If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
|
||||
"""
|
||||
from . import SchemaTypePath
|
||||
from ._loader import get_plugins
|
||||
|
||||
plugins = get_plugins()
|
||||
if plugins:
|
||||
return PluggableSchemaValidator(
|
||||
schema,
|
||||
schema_type,
|
||||
SchemaTypePath(schema_type_module, schema_type_name),
|
||||
schema_kind,
|
||||
config,
|
||||
plugins,
|
||||
plugin_settings or {},
|
||||
)
|
||||
else:
|
||||
return SchemaValidator(schema, config)
|
||||
|
||||
|
||||
class PluggableSchemaValidator:
|
||||
"""Pluggable schema validator."""
|
||||
|
||||
__slots__ = '_schema_validator', 'validate_json', 'validate_python', 'validate_strings'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_path: SchemaTypePath,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None,
|
||||
plugins: Iterable[PydanticPluginProtocol],
|
||||
plugin_settings: dict[str, Any],
|
||||
) -> None:
|
||||
self._schema_validator = SchemaValidator(schema, config)
|
||||
|
||||
python_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
json_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
for plugin in plugins:
|
||||
try:
|
||||
p, j, s = plugin.new_schema_validator(
|
||||
schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
|
||||
)
|
||||
except TypeError as e: # pragma: no cover
|
||||
raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
|
||||
if p is not None:
|
||||
python_event_handlers.append(p)
|
||||
if j is not None:
|
||||
json_event_handlers.append(j)
|
||||
if s is not None:
|
||||
strings_event_handlers.append(s)
|
||||
|
||||
self.validate_python = build_wrapper(self._schema_validator.validate_python, python_event_handlers)
|
||||
self.validate_json = build_wrapper(self._schema_validator.validate_json, json_event_handlers)
|
||||
self.validate_strings = build_wrapper(self._schema_validator.validate_strings, strings_event_handlers)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._schema_validator, name)
|
||||
|
||||
|
||||
def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandlerProtocol]) -> Callable[P, R]:
|
||||
if not event_handlers:
|
||||
return func
|
||||
else:
|
||||
on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter'))
|
||||
on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success'))
|
||||
on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error'))
|
||||
on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception'))
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
for on_enter_handler in on_enters:
|
||||
on_enter_handler(*args, **kwargs)
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
except ValidationError as error:
|
||||
for on_error_handler in on_errors:
|
||||
on_error_handler(error)
|
||||
raise
|
||||
except Exception as exception:
|
||||
for on_exception_handler in on_exceptions:
|
||||
on_exception_handler(exception)
|
||||
raise
|
||||
else:
|
||||
for on_success_handler in on_successes:
|
||||
on_success_handler(result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def filter_handlers(handler_cls: BaseValidateHandlerProtocol, method_name: str) -> bool:
|
||||
"""Filter out handler methods which are not implemented by the plugin directly - e.g. are missing
|
||||
or are inherited from the protocol.
|
||||
"""
|
||||
handler = getattr(handler_cls, method_name, None)
|
||||
if handler is None:
|
||||
return False
|
||||
elif handler.__module__ == 'pydantic.plugin':
|
||||
# this is the original handler, from the protocol due to runtime inheritance
|
||||
# we don't want to call it
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
0
venv/lib/python3.11/site-packages/pydantic/py.typed
Normal file
0
venv/lib/python3.11/site-packages/pydantic/py.typed
Normal file
154
venv/lib/python3.11/site-packages/pydantic/root_model.py
Normal file
154
venv/lib/python3.11/site-packages/pydantic/root_model.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""RootModel class and type definitions."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import typing
|
||||
from copy import copy, deepcopy
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from . import PydanticUserError
|
||||
from ._internal import _model_construction, _repr
|
||||
from .main import BaseModel, _object_setattr
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Literal, Self, dataclass_transform
|
||||
|
||||
from .fields import Field as PydanticModelField
|
||||
from .fields import PrivateAttr as PydanticModelPrivateAttr
|
||||
|
||||
# dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform
|
||||
# takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform
|
||||
# on a new metaclass.
|
||||
@dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr))
|
||||
class _RootModelMetaclass(_model_construction.ModelMetaclass): ...
|
||||
else:
|
||||
_RootModelMetaclass = _model_construction.ModelMetaclass
|
||||
|
||||
__all__ = ('RootModel',)
|
||||
|
||||
RootModelRootType = typing.TypeVar('RootModelRootType')
|
||||
|
||||
|
||||
class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/models/#rootmodel-and-custom-root-types
|
||||
|
||||
A Pydantic `BaseModel` for the root object of the model.
|
||||
|
||||
Attributes:
|
||||
root: The root object of the model.
|
||||
__pydantic_root_model__: Whether the model is a RootModel.
|
||||
__pydantic_private__: Private fields in the model.
|
||||
__pydantic_extra__: Extra fields in the model.
|
||||
|
||||
"""
|
||||
|
||||
__pydantic_root_model__ = True
|
||||
__pydantic_private__ = None
|
||||
__pydantic_extra__ = None
|
||||
|
||||
root: RootModelRootType
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
extra = cls.model_config.get('extra')
|
||||
if extra is not None:
|
||||
raise PydanticUserError(
|
||||
"`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
|
||||
__tracebackhide__ = True
|
||||
if data:
|
||||
if root is not PydanticUndefined:
|
||||
raise ValueError(
|
||||
'"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
|
||||
)
|
||||
root = data # type: ignore
|
||||
self.__pydantic_validator__.validate_python(root, self_instance=self)
|
||||
|
||||
__init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess]
|
||||
|
||||
@classmethod
|
||||
def model_construct(cls, root: RootModelRootType, _fields_set: set[str] | None = None) -> Self: # type: ignore
|
||||
"""Create a new model using the provided root object and update fields set.
|
||||
|
||||
Args:
|
||||
root: The root object of the model.
|
||||
_fields_set: The set of fields to be updated.
|
||||
|
||||
Returns:
|
||||
The new model.
|
||||
|
||||
Raises:
|
||||
NotImplemented: If the model is not a subclass of `RootModel`.
|
||||
"""
|
||||
return super().model_construct(root=root, _fields_set=_fields_set)
|
||||
|
||||
def __getstate__(self) -> dict[Any, Any]:
|
||||
return {
|
||||
'__dict__': self.__dict__,
|
||||
'__pydantic_fields_set__': self.__pydantic_fields_set__,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[Any, Any]) -> None:
|
||||
_object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
|
||||
_object_setattr(self, '__dict__', state['__dict__'])
|
||||
|
||||
def __copy__(self) -> Self:
|
||||
"""Returns a shallow copy of the model."""
|
||||
cls = type(self)
|
||||
m = cls.__new__(cls)
|
||||
_object_setattr(m, '__dict__', copy(self.__dict__))
|
||||
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
|
||||
return m
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
|
||||
"""Returns a deep copy of the model."""
|
||||
cls = type(self)
|
||||
m = cls.__new__(cls)
|
||||
_object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
|
||||
# This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
|
||||
# and attempting a deepcopy would be marginally slower.
|
||||
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
|
||||
return m
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def model_dump( # type: ignore
|
||||
self,
|
||||
*,
|
||||
mode: Literal['json', 'python'] | str = 'python',
|
||||
include: Any = None,
|
||||
exclude: Any = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal['none', 'warn', 'error'] = True,
|
||||
serialize_as_any: bool = False,
|
||||
) -> Any:
|
||||
"""This method is included just to get a more accurate return type for type checkers.
|
||||
It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.
|
||||
|
||||
See the documentation of `BaseModel.model_dump` for more details about the arguments.
|
||||
|
||||
Generally, this method will have a return type of `RootModelRootType`, assuming that `RootModelRootType` is
|
||||
not a `BaseModel` subclass. If `RootModelRootType` is a `BaseModel` subclass, then the return
|
||||
type will likely be `dict[str, Any]`, as `model_dump` calls are recursive. The return type could
|
||||
even be something different, in the case of a custom serializer.
|
||||
Thus, `Any` is used here to catch all of these cases.
|
||||
"""
|
||||
...
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, RootModel):
|
||||
return NotImplemented
|
||||
return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other)
|
||||
|
||||
def __repr_args__(self) -> _repr.ReprArgs:
|
||||
yield 'root', self.root
|
||||
5
venv/lib/python3.11/site-packages/pydantic/schema.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/schema.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""The `schema` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
5
venv/lib/python3.11/site-packages/pydantic/tools.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/tools.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""The `tools` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
594
venv/lib/python3.11/site-packages/pydantic/type_adapter.py
Normal file
594
venv/lib/python3.11/site-packages/pydantic/type_adapter.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""Type adapter specification."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import is_dataclass
|
||||
from functools import cached_property, wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Literal,
|
||||
TypeVar,
|
||||
cast,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
|
||||
from typing_extensions import Concatenate, ParamSpec, is_typeddict
|
||||
|
||||
from pydantic.errors import PydanticUserError
|
||||
from pydantic.main import BaseModel, IncEx
|
||||
|
||||
from ._internal import _config, _generate_schema, _mock_val_ser, _typing_extra, _utils
|
||||
from .config import ConfigDict
|
||||
from .json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
JsonSchemaKeyT,
|
||||
JsonSchemaMode,
|
||||
JsonSchemaValue,
|
||||
)
|
||||
from .plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
|
||||
|
||||
T = TypeVar('T')
|
||||
R = TypeVar('R')
|
||||
P = ParamSpec('P')
|
||||
TypeAdapterT = TypeVar('TypeAdapterT', bound='TypeAdapter')
|
||||
|
||||
|
||||
def _get_schema(type_: Any, config_wrapper: _config.ConfigWrapper, parent_depth: int) -> CoreSchema:
|
||||
"""`BaseModel` uses its own `__module__` to find out where it was defined
|
||||
and then looks for symbols to resolve forward references in those globals.
|
||||
On the other hand this function can be called with arbitrary objects,
|
||||
including type aliases, where `__module__` (always `typing.py`) is not useful.
|
||||
So instead we look at the globals in our parent stack frame.
|
||||
|
||||
This works for the case where this function is called in a module that
|
||||
has the target of forward references in its scope, but
|
||||
does not always work for more complex cases.
|
||||
|
||||
For example, take the following:
|
||||
|
||||
a.py
|
||||
```python
|
||||
from typing import Dict, List
|
||||
|
||||
IntList = List[int]
|
||||
OuterDict = Dict[str, 'IntList']
|
||||
```
|
||||
|
||||
b.py
|
||||
```python test="skip"
|
||||
from a import OuterDict
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
IntList = int # replaces the symbol the forward reference is looking for
|
||||
v = TypeAdapter(OuterDict)
|
||||
v({'x': 1}) # should fail but doesn't
|
||||
```
|
||||
|
||||
If `OuterDict` were a `BaseModel`, this would work because it would resolve
|
||||
the forward reference within the `a.py` namespace.
|
||||
But `TypeAdapter(OuterDict)` can't determine what module `OuterDict` came from.
|
||||
|
||||
In other words, the assumption that _all_ forward references exist in the
|
||||
module we are being called from is not technically always true.
|
||||
Although most of the time it is and it works fine for recursive models and such,
|
||||
`BaseModel`'s behavior isn't perfect either and _can_ break in similar ways,
|
||||
so there is no right or wrong between the two.
|
||||
|
||||
But at the very least this behavior is _subtly_ different from `BaseModel`'s.
|
||||
"""
|
||||
local_ns = _typing_extra.parent_frame_namespace(parent_depth=parent_depth)
|
||||
global_ns = sys._getframe(max(parent_depth - 1, 1)).f_globals.copy()
|
||||
global_ns.update(local_ns or {})
|
||||
gen = (config_wrapper.schema_generator or _generate_schema.GenerateSchema)(
|
||||
config_wrapper, types_namespace=global_ns, typevars_map={}
|
||||
)
|
||||
schema = gen.generate_schema(type_)
|
||||
schema = gen.clean_schema(schema)
|
||||
return schema
|
||||
|
||||
|
||||
def _getattr_no_parents(obj: Any, attribute: str) -> Any:
|
||||
"""Returns the attribute value without attempting to look up attributes from parent types."""
|
||||
if hasattr(obj, '__dict__'):
|
||||
try:
|
||||
return obj.__dict__[attribute]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
slots = getattr(obj, '__slots__', None)
|
||||
if slots is not None and attribute in slots:
|
||||
return getattr(obj, attribute)
|
||||
else:
|
||||
raise AttributeError(attribute)
|
||||
|
||||
|
||||
def _type_has_config(type_: Any) -> bool:
|
||||
"""Returns whether the type has config."""
|
||||
type_ = _typing_extra.annotated_type(type_) or type_
|
||||
try:
|
||||
return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)
|
||||
except TypeError:
|
||||
# type is not a class
|
||||
return False
|
||||
|
||||
|
||||
# This is keeping track of the frame depth for the TypeAdapter functions. This is required for _parent_depth used for
|
||||
# ForwardRef resolution. We may enter the TypeAdapter schema building via different TypeAdapter functions. Hence, we
|
||||
# need to keep track of the frame depth relative to the originally provided _parent_depth.
|
||||
def _frame_depth(
|
||||
depth: int,
|
||||
) -> Callable[[Callable[Concatenate[TypeAdapterT, P], R]], Callable[Concatenate[TypeAdapterT, P], R]]:
|
||||
def wrapper(func: Callable[Concatenate[TypeAdapterT, P], R]) -> Callable[Concatenate[TypeAdapterT, P], R]:
|
||||
@wraps(func)
|
||||
def wrapped(self: TypeAdapterT, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
with self._with_frame_depth(depth + 1): # depth + 1 for the wrapper function
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@final
|
||||
class TypeAdapter(Generic[T]):
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/type_adapter/
|
||||
|
||||
Type adapters provide a flexible way to perform validation and serialization based on a Python type.
|
||||
|
||||
A `TypeAdapter` instance exposes some of the functionality from `BaseModel` instance methods
|
||||
for types that do not have such methods (such as dataclasses, primitive types, and more).
|
||||
|
||||
**Note:** `TypeAdapter` instances are not types, and cannot be used as type annotations for fields.
|
||||
|
||||
**Note:** By default, `TypeAdapter` does not respect the
|
||||
[`defer_build=True`][pydantic.config.ConfigDict.defer_build] setting in the
|
||||
[`model_config`][pydantic.BaseModel.model_config] or in the `TypeAdapter` constructor `config`. You need to also
|
||||
explicitly set [`experimental_defer_build_mode=('model', 'type_adapter')`][pydantic.config.ConfigDict.experimental_defer_build_mode] of the
|
||||
config to defer the model validator and serializer construction. Thus, this feature is opt-in to ensure backwards
|
||||
compatibility.
|
||||
|
||||
Attributes:
|
||||
core_schema: The core schema for the type.
|
||||
validator (SchemaValidator): The schema validator for the type.
|
||||
serializer: The schema serializer for the type.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
type: type[T],
|
||||
*,
|
||||
config: ConfigDict | None = ...,
|
||||
_parent_depth: int = ...,
|
||||
module: str | None = ...,
|
||||
) -> None: ...
|
||||
|
||||
# This second overload is for unsupported special forms (such as Annotated, Union, etc.)
|
||||
# Currently there is no way to type this correctly
|
||||
# See https://github.com/python/typing/pull/1618
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
type: Any,
|
||||
*,
|
||||
config: ConfigDict | None = ...,
|
||||
_parent_depth: int = ...,
|
||||
module: str | None = ...,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Any,
|
||||
*,
|
||||
config: ConfigDict | None = None,
|
||||
_parent_depth: int = 2,
|
||||
module: str | None = None,
|
||||
) -> None:
|
||||
"""Initializes the TypeAdapter object.
|
||||
|
||||
Args:
|
||||
type: The type associated with the `TypeAdapter`.
|
||||
config: Configuration for the `TypeAdapter`, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
|
||||
_parent_depth: depth at which to search the parent namespace to construct the local namespace.
|
||||
module: The module that passes to plugin if provided.
|
||||
|
||||
!!! note
|
||||
You cannot use the `config` argument when instantiating a `TypeAdapter` if the type you're using has its own
|
||||
config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A
|
||||
[`type-adapter-config-unused`](../errors/usage_errors.md#type-adapter-config-unused) error will be raised in this case.
|
||||
|
||||
!!! note
|
||||
The `_parent_depth` argument is named with an underscore to suggest its private nature and discourage use.
|
||||
It may be deprecated in a minor version, so we only recommend using it if you're
|
||||
comfortable with potential change in behavior / support.
|
||||
|
||||
??? tip "Compatibility with `mypy`"
|
||||
Depending on the type used, `mypy` might raise an error when instantiating a `TypeAdapter`. As a workaround, you can explicitly
|
||||
annotate your variable:
|
||||
|
||||
```py
|
||||
from typing import Union
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
ta: TypeAdapter[Union[str, int]] = TypeAdapter(Union[str, int]) # type: ignore[arg-type]
|
||||
```
|
||||
|
||||
Returns:
|
||||
A type adapter configured for the specified `type`.
|
||||
"""
|
||||
if _type_has_config(type) and config is not None:
|
||||
raise PydanticUserError(
|
||||
'Cannot use `config` when the type is a BaseModel, dataclass or TypedDict.'
|
||||
' These types can have their own config and setting the config via the `config`'
|
||||
' parameter to TypeAdapter will not override it, thus the `config` you passed to'
|
||||
' TypeAdapter becomes meaningless, which is probably not what you want.',
|
||||
code='type-adapter-config-unused',
|
||||
)
|
||||
|
||||
self._type = type
|
||||
self._config = config
|
||||
self._parent_depth = _parent_depth
|
||||
if module is None:
|
||||
f = sys._getframe(1)
|
||||
self._module_name = cast(str, f.f_globals.get('__name__', ''))
|
||||
else:
|
||||
self._module_name = module
|
||||
|
||||
self._core_schema: CoreSchema | None = None
|
||||
self._validator: SchemaValidator | PluggableSchemaValidator | None = None
|
||||
self._serializer: SchemaSerializer | None = None
|
||||
|
||||
if not self._defer_build():
|
||||
# Immediately initialize the core schema, validator and serializer
|
||||
with self._with_frame_depth(1): # +1 frame depth for this __init__
|
||||
# Model itself may be using deferred building. For backward compatibility we don't rebuild model mocks
|
||||
# here as part of __init__ even though TypeAdapter itself is not using deferred building.
|
||||
self._init_core_attrs(rebuild_mocks=False)
|
||||
|
||||
@contextmanager
|
||||
def _with_frame_depth(self, depth: int) -> Iterator[None]:
|
||||
self._parent_depth += depth
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._parent_depth -= depth
|
||||
|
||||
@_frame_depth(1)
|
||||
def _init_core_attrs(self, rebuild_mocks: bool) -> None:
|
||||
try:
|
||||
self._core_schema = _getattr_no_parents(self._type, '__pydantic_core_schema__')
|
||||
self._validator = _getattr_no_parents(self._type, '__pydantic_validator__')
|
||||
self._serializer = _getattr_no_parents(self._type, '__pydantic_serializer__')
|
||||
except AttributeError:
|
||||
config_wrapper = _config.ConfigWrapper(self._config)
|
||||
core_config = config_wrapper.core_config(None)
|
||||
|
||||
self._core_schema = _get_schema(self._type, config_wrapper, parent_depth=self._parent_depth)
|
||||
self._validator = create_schema_validator(
|
||||
schema=self._core_schema,
|
||||
schema_type=self._type,
|
||||
schema_type_module=self._module_name,
|
||||
schema_type_name=str(self._type),
|
||||
schema_kind='TypeAdapter',
|
||||
config=core_config,
|
||||
plugin_settings=config_wrapper.plugin_settings,
|
||||
)
|
||||
self._serializer = SchemaSerializer(self._core_schema, core_config)
|
||||
|
||||
if rebuild_mocks and isinstance(self._core_schema, _mock_val_ser.MockCoreSchema):
|
||||
self._core_schema.rebuild()
|
||||
self._init_core_attrs(rebuild_mocks=False)
|
||||
assert not isinstance(self._core_schema, _mock_val_ser.MockCoreSchema)
|
||||
assert not isinstance(self._validator, _mock_val_ser.MockValSer)
|
||||
assert not isinstance(self._serializer, _mock_val_ser.MockValSer)
|
||||
|
||||
@cached_property
|
||||
@_frame_depth(2) # +2 for @cached_property and core_schema(self)
|
||||
def core_schema(self) -> CoreSchema:
|
||||
"""The pydantic-core schema used to build the SchemaValidator and SchemaSerializer."""
|
||||
if self._core_schema is None or isinstance(self._core_schema, _mock_val_ser.MockCoreSchema):
|
||||
self._init_core_attrs(rebuild_mocks=True) # Do not expose MockCoreSchema from public function
|
||||
assert self._core_schema is not None and not isinstance(self._core_schema, _mock_val_ser.MockCoreSchema)
|
||||
return self._core_schema
|
||||
|
||||
@cached_property
|
||||
@_frame_depth(2) # +2 for @cached_property + validator(self)
|
||||
def validator(self) -> SchemaValidator | PluggableSchemaValidator:
|
||||
"""The pydantic-core SchemaValidator used to validate instances of the model."""
|
||||
if not isinstance(self._validator, (SchemaValidator, PluggableSchemaValidator)):
|
||||
self._init_core_attrs(rebuild_mocks=True) # Do not expose MockValSer from public function
|
||||
assert isinstance(self._validator, (SchemaValidator, PluggableSchemaValidator))
|
||||
return self._validator
|
||||
|
||||
@cached_property
|
||||
@_frame_depth(2) # +2 for @cached_property + serializer(self)
|
||||
def serializer(self) -> SchemaSerializer:
|
||||
"""The pydantic-core SchemaSerializer used to dump instances of the model."""
|
||||
if not isinstance(self._serializer, SchemaSerializer):
|
||||
self._init_core_attrs(rebuild_mocks=True) # Do not expose MockValSer from public function
|
||||
assert isinstance(self._serializer, SchemaSerializer)
|
||||
return self._serializer
|
||||
|
||||
def _defer_build(self) -> bool:
|
||||
config = self._config if self._config is not None else self._model_config()
|
||||
return self._is_defer_build_config(config) if config is not None else False
|
||||
|
||||
def _model_config(self) -> ConfigDict | None:
|
||||
type_: Any = _typing_extra.annotated_type(self._type) or self._type # Eg FastAPI heavily uses Annotated
|
||||
if _utils.lenient_issubclass(type_, BaseModel):
|
||||
return type_.model_config
|
||||
return getattr(type_, '__pydantic_config__', None)
|
||||
|
||||
@staticmethod
|
||||
def _is_defer_build_config(config: ConfigDict) -> bool:
|
||||
# TODO reevaluate this logic when we have a better understanding of how defer_build should work with TypeAdapter
|
||||
# Should we drop the special experimental_defer_build_mode check?
|
||||
return config.get('defer_build', False) is True and 'type_adapter' in config.get(
|
||||
'experimental_defer_build_mode', ()
|
||||
)
|
||||
|
||||
@_frame_depth(1)
|
||||
def validate_python(
|
||||
self,
|
||||
object: Any,
|
||||
/,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
from_attributes: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> T:
|
||||
"""Validate a Python object against the model.
|
||||
|
||||
Args:
|
||||
object: The Python object to validate against the model.
|
||||
strict: Whether to strictly check types.
|
||||
from_attributes: Whether to extract data from object attributes.
|
||||
context: Additional context to pass to the validator.
|
||||
|
||||
!!! note
|
||||
When using `TypeAdapter` with a Pydantic `dataclass`, the use of the `from_attributes`
|
||||
argument is not supported.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
return self.validator.validate_python(object, strict=strict, from_attributes=from_attributes, context=context)
|
||||
|
||||
@_frame_depth(1)
|
||||
def validate_json(
|
||||
self, data: str | bytes, /, *, strict: bool | None = None, context: dict[str, Any] | None = None
|
||||
) -> T:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/json/#json-parsing
|
||||
|
||||
Validate a JSON string or bytes against the model.
|
||||
|
||||
Args:
|
||||
data: The JSON data to validate against the model.
|
||||
strict: Whether to strictly check types.
|
||||
context: Additional context to use during validation.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
return self.validator.validate_json(data, strict=strict, context=context)
|
||||
|
||||
@_frame_depth(1)
|
||||
def validate_strings(self, obj: Any, /, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> T:
|
||||
"""Validate object contains string data against the model.
|
||||
|
||||
Args:
|
||||
obj: The object contains string data to validate.
|
||||
strict: Whether to strictly check types.
|
||||
context: Additional context to use during validation.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
return self.validator.validate_strings(obj, strict=strict, context=context)
|
||||
|
||||
@_frame_depth(1)
|
||||
def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None:
|
||||
"""Get the default value for the wrapped type.
|
||||
|
||||
Args:
|
||||
strict: Whether to strictly check types.
|
||||
context: Additional context to pass to the validator.
|
||||
|
||||
Returns:
|
||||
The default value wrapped in a `Some` if there is one or None if not.
|
||||
"""
|
||||
return self.validator.get_default_value(strict=strict, context=context)
|
||||
|
||||
@_frame_depth(1)
|
||||
def dump_python(
|
||||
self,
|
||||
instance: T,
|
||||
/,
|
||||
*,
|
||||
mode: Literal['json', 'python'] = 'python',
|
||||
include: IncEx | None = None,
|
||||
exclude: IncEx | None = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal['none', 'warn', 'error'] = True,
|
||||
serialize_as_any: bool = False,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> Any:
|
||||
"""Dump an instance of the adapted type to a Python object.
|
||||
|
||||
Args:
|
||||
instance: The Python object to serialize.
|
||||
mode: The output format.
|
||||
include: Fields to include in the output.
|
||||
exclude: Fields to exclude from the output.
|
||||
by_alias: Whether to use alias names for field names.
|
||||
exclude_unset: Whether to exclude unset fields.
|
||||
exclude_defaults: Whether to exclude fields with default values.
|
||||
exclude_none: Whether to exclude fields with None values.
|
||||
round_trip: Whether to output the serialized data in a way that is compatible with deserialization.
|
||||
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
|
||||
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
|
||||
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
|
||||
context: Additional context to pass to the serializer.
|
||||
|
||||
Returns:
|
||||
The serialized object.
|
||||
"""
|
||||
return self.serializer.to_python(
|
||||
instance,
|
||||
mode=mode,
|
||||
by_alias=by_alias,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
round_trip=round_trip,
|
||||
warnings=warnings,
|
||||
serialize_as_any=serialize_as_any,
|
||||
context=context,
|
||||
)
|
||||
|
||||
@_frame_depth(1)
|
||||
def dump_json(
|
||||
self,
|
||||
instance: T,
|
||||
/,
|
||||
*,
|
||||
indent: int | None = None,
|
||||
include: IncEx | None = None,
|
||||
exclude: IncEx | None = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal['none', 'warn', 'error'] = True,
|
||||
serialize_as_any: bool = False,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> bytes:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.9/concepts/json/#json-serialization
|
||||
|
||||
Serialize an instance of the adapted type to JSON.
|
||||
|
||||
Args:
|
||||
instance: The instance to be serialized.
|
||||
indent: Number of spaces for JSON indentation.
|
||||
include: Fields to include.
|
||||
exclude: Fields to exclude.
|
||||
by_alias: Whether to use alias names for field names.
|
||||
exclude_unset: Whether to exclude unset fields.
|
||||
exclude_defaults: Whether to exclude fields with default values.
|
||||
exclude_none: Whether to exclude fields with a value of `None`.
|
||||
round_trip: Whether to serialize and deserialize the instance to ensure round-tripping.
|
||||
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
|
||||
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
|
||||
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
|
||||
context: Additional context to pass to the serializer.
|
||||
|
||||
Returns:
|
||||
The JSON representation of the given instance as bytes.
|
||||
"""
|
||||
return self.serializer.to_json(
|
||||
instance,
|
||||
indent=indent,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
round_trip=round_trip,
|
||||
warnings=warnings,
|
||||
serialize_as_any=serialize_as_any,
|
||||
context=context,
|
||||
)
|
||||
|
||||
@_frame_depth(1)
|
||||
def json_schema(
|
||||
self,
|
||||
*,
|
||||
by_alias: bool = True,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
mode: JsonSchemaMode = 'validation',
|
||||
) -> dict[str, Any]:
|
||||
"""Generate a JSON schema for the adapted type.
|
||||
|
||||
Args:
|
||||
by_alias: Whether to use alias names for field names.
|
||||
ref_template: The format string used for generating $ref strings.
|
||||
schema_generator: The generator class used for creating the schema.
|
||||
mode: The mode to use for schema generation.
|
||||
|
||||
Returns:
|
||||
The JSON schema for the model as a dictionary.
|
||||
"""
|
||||
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
|
||||
return schema_generator_instance.generate(self.core_schema, mode=mode)
|
||||
|
||||
@staticmethod
|
||||
def json_schemas(
|
||||
inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]],
|
||||
/,
|
||||
*,
|
||||
by_alias: bool = True,
|
||||
title: str | None = None,
|
||||
description: str | None = None,
|
||||
ref_template: str = DEFAULT_REF_TEMPLATE,
|
||||
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
|
||||
) -> tuple[dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], JsonSchemaValue]:
|
||||
"""Generate a JSON schema including definitions from multiple type adapters.
|
||||
|
||||
Args:
|
||||
inputs: Inputs to schema generation. The first two items will form the keys of the (first)
|
||||
output mapping; the type adapters will provide the core schemas that get converted into
|
||||
definitions in the output JSON schema.
|
||||
by_alias: Whether to use alias names.
|
||||
title: The title for the schema.
|
||||
description: The description for the schema.
|
||||
ref_template: The format string used for generating $ref strings.
|
||||
schema_generator: The generator class used for creating the schema.
|
||||
|
||||
Returns:
|
||||
A tuple where:
|
||||
|
||||
- The first element is a dictionary whose keys are tuples of JSON schema key type and JSON mode, and
|
||||
whose values are the JSON schema corresponding to that pair of inputs. (These schemas may have
|
||||
JsonRef references to definitions that are defined in the second returned element.)
|
||||
- The second element is a JSON schema containing all definitions referenced in the first returned
|
||||
element, along with the optional title and description keys.
|
||||
|
||||
"""
|
||||
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
|
||||
|
||||
inputs_ = []
|
||||
for key, mode, adapter in inputs:
|
||||
with adapter._with_frame_depth(1): # +1 for json_schemas staticmethod
|
||||
inputs_.append((key, mode, adapter.core_schema))
|
||||
|
||||
json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs_)
|
||||
|
||||
json_schema: dict[str, Any] = {}
|
||||
if definitions:
|
||||
json_schema['$defs'] = definitions
|
||||
if title:
|
||||
json_schema['title'] = title
|
||||
if description:
|
||||
json_schema['description'] = description
|
||||
|
||||
return json_schemas_map, json_schema
|
||||
3074
venv/lib/python3.11/site-packages/pydantic/types.py
Normal file
3074
venv/lib/python3.11/site-packages/pydantic/types.py
Normal file
File diff suppressed because it is too large
Load Diff
5
venv/lib/python3.11/site-packages/pydantic/typing.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/typing.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""`typing` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
5
venv/lib/python3.11/site-packages/pydantic/utils.py
Normal file
5
venv/lib/python3.11/site-packages/pydantic/utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""The `utils` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
131
venv/lib/python3.11/site-packages/pydantic/v1/__init__.py
Normal file
131
venv/lib/python3.11/site-packages/pydantic/v1/__init__.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# flake8: noqa
|
||||
from pydantic.v1 import dataclasses
|
||||
from pydantic.v1.annotated_types import create_model_from_namedtuple, create_model_from_typeddict
|
||||
from pydantic.v1.class_validators import root_validator, validator
|
||||
from pydantic.v1.config import BaseConfig, ConfigDict, Extra
|
||||
from pydantic.v1.decorator import validate_arguments
|
||||
from pydantic.v1.env_settings import BaseSettings
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
from pydantic.v1.errors import *
|
||||
from pydantic.v1.fields import Field, PrivateAttr, Required
|
||||
from pydantic.v1.main import *
|
||||
from pydantic.v1.networks import *
|
||||
from pydantic.v1.parse import Protocol
|
||||
from pydantic.v1.tools import *
|
||||
from pydantic.v1.types import *
|
||||
from pydantic.v1.version import VERSION, compiled
|
||||
|
||||
__version__ = VERSION
|
||||
|
||||
# WARNING __all__ from pydantic.errors is not included here, it will be removed as an export here in v2
|
||||
# please use "from pydantic.v1.errors import ..." instead
|
||||
__all__ = [
|
||||
# annotated types utils
|
||||
'create_model_from_namedtuple',
|
||||
'create_model_from_typeddict',
|
||||
# dataclasses
|
||||
'dataclasses',
|
||||
# class_validators
|
||||
'root_validator',
|
||||
'validator',
|
||||
# config
|
||||
'BaseConfig',
|
||||
'ConfigDict',
|
||||
'Extra',
|
||||
# decorator
|
||||
'validate_arguments',
|
||||
# env_settings
|
||||
'BaseSettings',
|
||||
# error_wrappers
|
||||
'ValidationError',
|
||||
# fields
|
||||
'Field',
|
||||
'Required',
|
||||
# main
|
||||
'BaseModel',
|
||||
'create_model',
|
||||
'validate_model',
|
||||
# network
|
||||
'AnyUrl',
|
||||
'AnyHttpUrl',
|
||||
'FileUrl',
|
||||
'HttpUrl',
|
||||
'stricturl',
|
||||
'EmailStr',
|
||||
'NameEmail',
|
||||
'IPvAnyAddress',
|
||||
'IPvAnyInterface',
|
||||
'IPvAnyNetwork',
|
||||
'PostgresDsn',
|
||||
'CockroachDsn',
|
||||
'AmqpDsn',
|
||||
'RedisDsn',
|
||||
'MongoDsn',
|
||||
'KafkaDsn',
|
||||
'validate_email',
|
||||
# parse
|
||||
'Protocol',
|
||||
# tools
|
||||
'parse_file_as',
|
||||
'parse_obj_as',
|
||||
'parse_raw_as',
|
||||
'schema_of',
|
||||
'schema_json_of',
|
||||
# types
|
||||
'NoneStr',
|
||||
'NoneBytes',
|
||||
'StrBytes',
|
||||
'NoneStrBytes',
|
||||
'StrictStr',
|
||||
'ConstrainedBytes',
|
||||
'conbytes',
|
||||
'ConstrainedList',
|
||||
'conlist',
|
||||
'ConstrainedSet',
|
||||
'conset',
|
||||
'ConstrainedFrozenSet',
|
||||
'confrozenset',
|
||||
'ConstrainedStr',
|
||||
'constr',
|
||||
'PyObject',
|
||||
'ConstrainedInt',
|
||||
'conint',
|
||||
'PositiveInt',
|
||||
'NegativeInt',
|
||||
'NonNegativeInt',
|
||||
'NonPositiveInt',
|
||||
'ConstrainedFloat',
|
||||
'confloat',
|
||||
'PositiveFloat',
|
||||
'NegativeFloat',
|
||||
'NonNegativeFloat',
|
||||
'NonPositiveFloat',
|
||||
'FiniteFloat',
|
||||
'ConstrainedDecimal',
|
||||
'condecimal',
|
||||
'ConstrainedDate',
|
||||
'condate',
|
||||
'UUID1',
|
||||
'UUID3',
|
||||
'UUID4',
|
||||
'UUID5',
|
||||
'FilePath',
|
||||
'DirectoryPath',
|
||||
'Json',
|
||||
'JsonWrapper',
|
||||
'SecretField',
|
||||
'SecretStr',
|
||||
'SecretBytes',
|
||||
'StrictBool',
|
||||
'StrictBytes',
|
||||
'StrictInt',
|
||||
'StrictFloat',
|
||||
'PaymentCardNumber',
|
||||
'PrivateAttr',
|
||||
'ByteSize',
|
||||
'PastDate',
|
||||
'FutureDate',
|
||||
# version
|
||||
'compiled',
|
||||
'VERSION',
|
||||
]
|
||||
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
Register Hypothesis strategies for Pydantic custom types.
|
||||
|
||||
This enables fully-automatic generation of test data for most Pydantic classes.
|
||||
|
||||
Note that this module has *no* runtime impact on Pydantic itself; instead it
|
||||
is registered as a setuptools entry point and Hypothesis will import it if
|
||||
Pydantic is installed. See also:
|
||||
|
||||
https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points
|
||||
https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy
|
||||
https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov
|
||||
https://docs.pydantic.dev/usage/types/#pydantic-types
|
||||
|
||||
Note that because our motivation is to *improve user experience*, the strategies
|
||||
are always sound (never generate invalid data) but sacrifice completeness for
|
||||
maintainability (ie may be unable to generate some tricky but valid data).
|
||||
|
||||
Finally, this module makes liberal use of `# type: ignore[<code>]` pragmas.
|
||||
This is because Hypothesis annotates `register_type_strategy()` with
|
||||
`(T, SearchStrategy[T])`, but in most cases we register e.g. `ConstrainedInt`
|
||||
to generate instances of the builtin `int` type which match the constraints.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import ipaddress
|
||||
import json
|
||||
import math
|
||||
from fractions import Fraction
|
||||
from typing import Callable, Dict, Type, Union, cast, overload
|
||||
|
||||
import hypothesis.strategies as st
|
||||
|
||||
import pydantic
|
||||
import pydantic.color
|
||||
import pydantic.types
|
||||
from pydantic.v1.utils import lenient_issubclass
|
||||
|
||||
# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create
|
||||
# them on-disk, and that's unsafe in general without being told *where* to do so.
|
||||
#
|
||||
# URLs are unsupported because it's easy for users to define their own strategy for
|
||||
# "normal" URLs, and hard for us to define a general strategy which includes "weird"
|
||||
# URLs but doesn't also have unpredictable performance problems.
|
||||
#
|
||||
# conlist() and conset() are unsupported for now, because the workarounds for
|
||||
# Cython and Hypothesis to handle parametrized generic types are incompatible.
|
||||
# We are rethinking Hypothesis compatibility in Pydantic v2.
|
||||
|
||||
# Emails
|
||||
try:
|
||||
import email_validator
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
else:
|
||||
|
||||
def is_valid_email(s: str) -> bool:
|
||||
# Hypothesis' st.emails() occasionally generates emails like 0@A0--0.ac
|
||||
# that are invalid according to email-validator, so we filter those out.
|
||||
try:
|
||||
email_validator.validate_email(s, check_deliverability=False)
|
||||
return True
|
||||
except email_validator.EmailNotValidError: # pragma: no cover
|
||||
return False
|
||||
|
||||
# Note that these strategies deliberately stay away from any tricky Unicode
|
||||
# or other encoding issues; we're just trying to generate *something* valid.
|
||||
st.register_type_strategy(pydantic.EmailStr, st.emails().filter(is_valid_email)) # type: ignore[arg-type]
|
||||
st.register_type_strategy(
|
||||
pydantic.NameEmail,
|
||||
st.builds(
|
||||
'{} <{}>'.format, # type: ignore[arg-type]
|
||||
st.from_regex('[A-Za-z0-9_]+( [A-Za-z0-9_]+){0,5}', fullmatch=True),
|
||||
st.emails().filter(is_valid_email),
|
||||
),
|
||||
)
|
||||
|
||||
# PyObject - dotted names, in this case taken from the math module.
|
||||
st.register_type_strategy(
|
||||
pydantic.PyObject, # type: ignore[arg-type]
|
||||
st.sampled_from(
|
||||
[cast(pydantic.PyObject, f'math.{name}') for name in sorted(vars(math)) if not name.startswith('_')]
|
||||
),
|
||||
)
|
||||
|
||||
# CSS3 Colors; as name, hex, rgb(a) tuples or strings, or hsl strings
|
||||
_color_regexes = (
|
||||
'|'.join(
|
||||
(
|
||||
pydantic.color.r_hex_short,
|
||||
pydantic.color.r_hex_long,
|
||||
pydantic.color.r_rgb,
|
||||
pydantic.color.r_rgba,
|
||||
pydantic.color.r_hsl,
|
||||
pydantic.color.r_hsla,
|
||||
)
|
||||
)
|
||||
# Use more precise regex patterns to avoid value-out-of-range errors
|
||||
.replace(pydantic.color._r_sl, r'(?:(\d\d?(?:\.\d+)?|100(?:\.0+)?)%)')
|
||||
.replace(pydantic.color._r_alpha, r'(?:(0(?:\.\d+)?|1(?:\.0+)?|\.\d+|\d{1,2}%))')
|
||||
.replace(pydantic.color._r_255, r'(?:((?:\d|\d\d|[01]\d\d|2[0-4]\d|25[0-4])(?:\.\d+)?|255(?:\.0+)?))')
|
||||
)
|
||||
st.register_type_strategy(
|
||||
pydantic.color.Color,
|
||||
st.one_of(
|
||||
st.sampled_from(sorted(pydantic.color.COLORS_BY_NAME)),
|
||||
st.tuples(
|
||||
st.integers(0, 255),
|
||||
st.integers(0, 255),
|
||||
st.integers(0, 255),
|
||||
st.none() | st.floats(0, 1) | st.floats(0, 100).map('{}%'.format),
|
||||
),
|
||||
st.from_regex(_color_regexes, fullmatch=True),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Card numbers, valid according to the Luhn algorithm
|
||||
|
||||
|
||||
def add_luhn_digit(card_number: str) -> str:
|
||||
# See https://en.wikipedia.org/wiki/Luhn_algorithm
|
||||
for digit in '0123456789':
|
||||
with contextlib.suppress(Exception):
|
||||
pydantic.PaymentCardNumber.validate_luhn_check_digit(card_number + digit)
|
||||
return card_number + digit
|
||||
raise AssertionError('Unreachable') # pragma: no cover
|
||||
|
||||
|
||||
card_patterns = (
|
||||
# Note that these patterns omit the Luhn check digit; that's added by the function above
|
||||
'4[0-9]{14}', # Visa
|
||||
'5[12345][0-9]{13}', # Mastercard
|
||||
'3[47][0-9]{12}', # American Express
|
||||
'[0-26-9][0-9]{10,17}', # other (incomplete to avoid overlap)
|
||||
)
|
||||
st.register_type_strategy(
|
||||
pydantic.PaymentCardNumber,
|
||||
st.from_regex('|'.join(card_patterns), fullmatch=True).map(add_luhn_digit), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# UUIDs
|
||||
st.register_type_strategy(pydantic.UUID1, st.uuids(version=1))
|
||||
st.register_type_strategy(pydantic.UUID3, st.uuids(version=3))
|
||||
st.register_type_strategy(pydantic.UUID4, st.uuids(version=4))
|
||||
st.register_type_strategy(pydantic.UUID5, st.uuids(version=5))
|
||||
|
||||
# Secrets
|
||||
st.register_type_strategy(pydantic.SecretBytes, st.binary().map(pydantic.SecretBytes))
|
||||
st.register_type_strategy(pydantic.SecretStr, st.text().map(pydantic.SecretStr))
|
||||
|
||||
# IP addresses, networks, and interfaces
|
||||
st.register_type_strategy(pydantic.IPvAnyAddress, st.ip_addresses()) # type: ignore[arg-type]
|
||||
st.register_type_strategy(
|
||||
pydantic.IPvAnyInterface,
|
||||
st.from_type(ipaddress.IPv4Interface) | st.from_type(ipaddress.IPv6Interface), # type: ignore[arg-type]
|
||||
)
|
||||
st.register_type_strategy(
|
||||
pydantic.IPvAnyNetwork,
|
||||
st.from_type(ipaddress.IPv4Network) | st.from_type(ipaddress.IPv6Network), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# We hook into the con***() functions and the ConstrainedNumberMeta metaclass,
|
||||
# so here we only have to register subclasses for other constrained types which
|
||||
# don't go via those mechanisms. Then there are the registration hooks below.
|
||||
st.register_type_strategy(pydantic.StrictBool, st.booleans())
|
||||
st.register_type_strategy(pydantic.StrictStr, st.text())
|
||||
|
||||
|
||||
# FutureDate, PastDate
|
||||
st.register_type_strategy(pydantic.FutureDate, st.dates(min_value=datetime.date.today() + datetime.timedelta(days=1)))
|
||||
st.register_type_strategy(pydantic.PastDate, st.dates(max_value=datetime.date.today() - datetime.timedelta(days=1)))
|
||||
|
||||
|
||||
# Constrained-type resolver functions
|
||||
#
|
||||
# For these ones, we actually want to inspect the type in order to work out a
|
||||
# satisfying strategy. First up, the machinery for tracking resolver functions:
|
||||
|
||||
RESOLVERS: Dict[type, Callable[[type], st.SearchStrategy]] = {} # type: ignore[type-arg]
|
||||
|
||||
|
||||
@overload
|
||||
def _registered(typ: Type[pydantic.types.T]) -> Type[pydantic.types.T]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def _registered(typ: pydantic.types.ConstrainedNumberMeta) -> pydantic.types.ConstrainedNumberMeta:
|
||||
pass
|
||||
|
||||
|
||||
def _registered(
|
||||
typ: Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]
|
||||
) -> Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]:
|
||||
# This function replaces the version in `pydantic.types`, in order to
|
||||
# effect the registration of new constrained types so that Hypothesis
|
||||
# can generate valid examples.
|
||||
pydantic.types._DEFINED_TYPES.add(typ)
|
||||
for supertype, resolver in RESOLVERS.items():
|
||||
if issubclass(typ, supertype):
|
||||
st.register_type_strategy(typ, resolver(typ)) # type: ignore
|
||||
return typ
|
||||
raise NotImplementedError(f'Unknown type {typ!r} has no resolver to register') # pragma: no cover
|
||||
|
||||
|
||||
def resolves(
|
||||
typ: Union[type, pydantic.types.ConstrainedNumberMeta]
|
||||
) -> Callable[[Callable[..., st.SearchStrategy]], Callable[..., st.SearchStrategy]]: # type: ignore[type-arg]
|
||||
def inner(f): # type: ignore
|
||||
assert f not in RESOLVERS
|
||||
RESOLVERS[typ] = f
|
||||
return f
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
# Type-to-strategy resolver functions
|
||||
|
||||
|
||||
@resolves(pydantic.JsonWrapper)
|
||||
def resolve_json(cls): # type: ignore[no-untyped-def]
|
||||
try:
|
||||
inner = st.none() if cls.inner_type is None else st.from_type(cls.inner_type)
|
||||
except Exception: # pragma: no cover
|
||||
finite = st.floats(allow_infinity=False, allow_nan=False)
|
||||
inner = st.recursive(
|
||||
base=st.one_of(st.none(), st.booleans(), st.integers(), finite, st.text()),
|
||||
extend=lambda x: st.lists(x) | st.dictionaries(st.text(), x), # type: ignore
|
||||
)
|
||||
inner_type = getattr(cls, 'inner_type', None)
|
||||
return st.builds(
|
||||
cls.inner_type.json if lenient_issubclass(inner_type, pydantic.BaseModel) else json.dumps,
|
||||
inner,
|
||||
ensure_ascii=st.booleans(),
|
||||
indent=st.none() | st.integers(0, 16),
|
||||
sort_keys=st.booleans(),
|
||||
)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedBytes)
|
||||
def resolve_conbytes(cls): # type: ignore[no-untyped-def] # pragma: no cover
|
||||
min_size = cls.min_length or 0
|
||||
max_size = cls.max_length
|
||||
if not cls.strip_whitespace:
|
||||
return st.binary(min_size=min_size, max_size=max_size)
|
||||
# Fun with regex to ensure we neither start nor end with whitespace
|
||||
repeats = '{{{},{}}}'.format(
|
||||
min_size - 2 if min_size > 2 else 0,
|
||||
max_size - 2 if (max_size or 0) > 2 else '',
|
||||
)
|
||||
if min_size >= 2:
|
||||
pattern = rf'\W.{repeats}\W'
|
||||
elif min_size == 1:
|
||||
pattern = rf'\W(.{repeats}\W)?'
|
||||
else:
|
||||
assert min_size == 0
|
||||
pattern = rf'(\W(.{repeats}\W)?)?'
|
||||
return st.from_regex(pattern.encode(), fullmatch=True)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedDecimal)
|
||||
def resolve_condecimal(cls): # type: ignore[no-untyped-def]
|
||||
min_value = cls.ge
|
||||
max_value = cls.le
|
||||
if cls.gt is not None:
|
||||
assert min_value is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.gt
|
||||
if cls.lt is not None:
|
||||
assert max_value is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.lt
|
||||
s = st.decimals(min_value, max_value, allow_nan=False, places=cls.decimal_places)
|
||||
if cls.lt is not None:
|
||||
s = s.filter(lambda d: d < cls.lt)
|
||||
if cls.gt is not None:
|
||||
s = s.filter(lambda d: cls.gt < d)
|
||||
return s
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedFloat)
|
||||
def resolve_confloat(cls): # type: ignore[no-untyped-def]
|
||||
min_value = cls.ge
|
||||
max_value = cls.le
|
||||
exclude_min = False
|
||||
exclude_max = False
|
||||
|
||||
if cls.gt is not None:
|
||||
assert min_value is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.gt
|
||||
exclude_min = True
|
||||
if cls.lt is not None:
|
||||
assert max_value is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.lt
|
||||
exclude_max = True
|
||||
|
||||
if cls.multiple_of is None:
|
||||
return st.floats(min_value, max_value, exclude_min=exclude_min, exclude_max=exclude_max, allow_nan=False)
|
||||
|
||||
if min_value is not None:
|
||||
min_value = math.ceil(min_value / cls.multiple_of)
|
||||
if exclude_min:
|
||||
min_value = min_value + 1
|
||||
if max_value is not None:
|
||||
assert max_value >= cls.multiple_of, 'Cannot build model with max value smaller than multiple of'
|
||||
max_value = math.floor(max_value / cls.multiple_of)
|
||||
if exclude_max:
|
||||
max_value = max_value - 1
|
||||
|
||||
return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedInt)
|
||||
def resolve_conint(cls): # type: ignore[no-untyped-def]
|
||||
min_value = cls.ge
|
||||
max_value = cls.le
|
||||
if cls.gt is not None:
|
||||
assert min_value is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.gt + 1
|
||||
if cls.lt is not None:
|
||||
assert max_value is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.lt - 1
|
||||
|
||||
if cls.multiple_of is None or cls.multiple_of == 1:
|
||||
return st.integers(min_value, max_value)
|
||||
|
||||
# These adjustments and the .map handle integer-valued multiples, while the
|
||||
# .filter handles trickier cases as for confloat.
|
||||
if min_value is not None:
|
||||
min_value = math.ceil(Fraction(min_value) / Fraction(cls.multiple_of))
|
||||
if max_value is not None:
|
||||
max_value = math.floor(Fraction(max_value) / Fraction(cls.multiple_of))
|
||||
return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedDate)
|
||||
def resolve_condate(cls): # type: ignore[no-untyped-def]
|
||||
if cls.ge is not None:
|
||||
assert cls.gt is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.ge
|
||||
elif cls.gt is not None:
|
||||
min_value = cls.gt + datetime.timedelta(days=1)
|
||||
else:
|
||||
min_value = datetime.date.min
|
||||
if cls.le is not None:
|
||||
assert cls.lt is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.le
|
||||
elif cls.lt is not None:
|
||||
max_value = cls.lt - datetime.timedelta(days=1)
|
||||
else:
|
||||
max_value = datetime.date.max
|
||||
return st.dates(min_value, max_value)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedStr)
|
||||
def resolve_constr(cls): # type: ignore[no-untyped-def] # pragma: no cover
|
||||
min_size = cls.min_length or 0
|
||||
max_size = cls.max_length
|
||||
|
||||
if cls.regex is None and not cls.strip_whitespace:
|
||||
return st.text(min_size=min_size, max_size=max_size)
|
||||
|
||||
if cls.regex is not None:
|
||||
strategy = st.from_regex(cls.regex)
|
||||
if cls.strip_whitespace:
|
||||
strategy = strategy.filter(lambda s: s == s.strip())
|
||||
elif cls.strip_whitespace:
|
||||
repeats = '{{{},{}}}'.format(
|
||||
min_size - 2 if min_size > 2 else 0,
|
||||
max_size - 2 if (max_size or 0) > 2 else '',
|
||||
)
|
||||
if min_size >= 2:
|
||||
strategy = st.from_regex(rf'\W.{repeats}\W')
|
||||
elif min_size == 1:
|
||||
strategy = st.from_regex(rf'\W(.{repeats}\W)?')
|
||||
else:
|
||||
assert min_size == 0
|
||||
strategy = st.from_regex(rf'(\W(.{repeats}\W)?)?')
|
||||
|
||||
if min_size == 0 and max_size is None:
|
||||
return strategy
|
||||
elif max_size is None:
|
||||
return strategy.filter(lambda s: min_size <= len(s))
|
||||
return strategy.filter(lambda s: min_size <= len(s) <= max_size)
|
||||
|
||||
|
||||
# Finally, register all previously-defined types, and patch in our new function
|
||||
for typ in list(pydantic.types._DEFINED_TYPES):
|
||||
_registered(typ)
|
||||
pydantic.types._registered = _registered
|
||||
st.register_type_strategy(pydantic.Json, resolve_json)
|
||||
@@ -0,0 +1,72 @@
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type
|
||||
|
||||
from pydantic.v1.fields import Required
|
||||
from pydantic.v1.main import BaseModel, create_model
|
||||
from pydantic.v1.typing import is_typeddict, is_typeddict_special
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
|
||||
def is_legacy_typeddict(typeddict_cls: Type['TypedDict']) -> bool: # type: ignore[valid-type]
|
||||
return is_typeddict(typeddict_cls) and type(typeddict_cls).__module__ == 'typing'
|
||||
|
||||
else:
|
||||
|
||||
def is_legacy_typeddict(_: Any) -> Any:
|
||||
return False
|
||||
|
||||
|
||||
def create_model_from_typeddict(
|
||||
# Mypy bug: `Type[TypedDict]` is resolved as `Any` https://github.com/python/mypy/issues/11030
|
||||
typeddict_cls: Type['TypedDict'], # type: ignore[valid-type]
|
||||
**kwargs: Any,
|
||||
) -> Type['BaseModel']:
|
||||
"""
|
||||
Create a `BaseModel` based on the fields of a `TypedDict`.
|
||||
Since `typing.TypedDict` in Python 3.8 does not store runtime information about optional keys,
|
||||
we raise an error if this happens (see https://bugs.python.org/issue38834).
|
||||
"""
|
||||
field_definitions: Dict[str, Any]
|
||||
|
||||
# Best case scenario: with python 3.9+ or when `TypedDict` is imported from `typing_extensions`
|
||||
if not hasattr(typeddict_cls, '__required_keys__'):
|
||||
raise TypeError(
|
||||
'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.9.2. '
|
||||
'Without it, there is no way to differentiate required and optional fields when subclassed.'
|
||||
)
|
||||
|
||||
if is_legacy_typeddict(typeddict_cls) and any(
|
||||
is_typeddict_special(t) for t in typeddict_cls.__annotations__.values()
|
||||
):
|
||||
raise TypeError(
|
||||
'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.11. '
|
||||
'Without it, there is no way to reflect Required/NotRequired keys.'
|
||||
)
|
||||
|
||||
required_keys: FrozenSet[str] = typeddict_cls.__required_keys__ # type: ignore[attr-defined]
|
||||
field_definitions = {
|
||||
field_name: (field_type, Required if field_name in required_keys else None)
|
||||
for field_name, field_type in typeddict_cls.__annotations__.items()
|
||||
}
|
||||
|
||||
return create_model(typeddict_cls.__name__, **kwargs, **field_definitions)
|
||||
|
||||
|
||||
def create_model_from_namedtuple(namedtuple_cls: Type['NamedTuple'], **kwargs: Any) -> Type['BaseModel']:
|
||||
"""
|
||||
Create a `BaseModel` based on the fields of a named tuple.
|
||||
A named tuple can be created with `typing.NamedTuple` and declared annotations
|
||||
but also with `collections.namedtuple`, in this case we consider all fields
|
||||
to have type `Any`.
|
||||
"""
|
||||
# With python 3.10+, `__annotations__` always exists but can be empty hence the `getattr... or...` logic
|
||||
namedtuple_annotations: Dict[str, Type[Any]] = getattr(namedtuple_cls, '__annotations__', None) or {
|
||||
k: Any for k in namedtuple_cls._fields
|
||||
}
|
||||
field_definitions: Dict[str, Any] = {
|
||||
field_name: (field_type, Required) for field_name, field_type in namedtuple_annotations.items()
|
||||
}
|
||||
return create_model(namedtuple_cls.__name__, **kwargs, **field_definitions)
|
||||
@@ -0,0 +1,361 @@
|
||||
import warnings
|
||||
from collections import ChainMap
|
||||
from functools import partial, partialmethod, wraps
|
||||
from itertools import chain
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
|
||||
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.typing import AnyCallable
|
||||
from pydantic.v1.utils import ROOT_KEY, in_ipython
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import AnyClassMethod
|
||||
|
||||
|
||||
class Validator:
|
||||
__slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: AnyCallable,
|
||||
pre: bool = False,
|
||||
each_item: bool = False,
|
||||
always: bool = False,
|
||||
check_fields: bool = False,
|
||||
skip_on_failure: bool = False,
|
||||
):
|
||||
self.func = func
|
||||
self.pre = pre
|
||||
self.each_item = each_item
|
||||
self.always = always
|
||||
self.check_fields = check_fields
|
||||
self.skip_on_failure = skip_on_failure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.types import ModelOrDc
|
||||
|
||||
ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
|
||||
ValidatorsList = List[ValidatorCallable]
|
||||
ValidatorListDict = Dict[str, List[Validator]]
|
||||
|
||||
_FUNCS: Set[str] = set()
|
||||
VALIDATOR_CONFIG_KEY = '__validator_config__'
|
||||
ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__'
|
||||
|
||||
|
||||
def validator(
|
||||
*fields: str,
|
||||
pre: bool = False,
|
||||
each_item: bool = False,
|
||||
always: bool = False,
|
||||
check_fields: bool = True,
|
||||
whole: Optional[bool] = None,
|
||||
allow_reuse: bool = False,
|
||||
) -> Callable[[AnyCallable], 'AnyClassMethod']:
|
||||
"""
|
||||
Decorate methods on the class indicating that they should be used to validate fields
|
||||
:param fields: which field(s) the method should be called on
|
||||
:param pre: whether or not this validator should be called before the standard validators (else after)
|
||||
:param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the
|
||||
whole object
|
||||
:param always: whether this method and other validators should be called even if the value is missing
|
||||
:param check_fields: whether to check that the fields actually exist on the model
|
||||
:param allow_reuse: whether to track and raise an error if another validator refers to the decorated function
|
||||
"""
|
||||
if not fields:
|
||||
raise ConfigError('validator with no fields specified')
|
||||
elif isinstance(fields[0], FunctionType):
|
||||
raise ConfigError(
|
||||
"validators should be used with fields and keyword arguments, not bare. " # noqa: Q000
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`"
|
||||
)
|
||||
elif not all(isinstance(field, str) for field in fields):
|
||||
raise ConfigError(
|
||||
"validator fields should be passed as separate string args. " # noqa: Q000
|
||||
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`"
|
||||
)
|
||||
|
||||
if whole is not None:
|
||||
warnings.warn(
|
||||
'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead',
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert each_item is False, '"each_item" and "whole" conflict, remove "whole"'
|
||||
each_item = not whole
|
||||
|
||||
def dec(f: AnyCallable) -> 'AnyClassMethod':
|
||||
f_cls = _prepare_validator(f, allow_reuse)
|
||||
setattr(
|
||||
f_cls,
|
||||
VALIDATOR_CONFIG_KEY,
|
||||
(
|
||||
fields,
|
||||
Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields),
|
||||
),
|
||||
)
|
||||
return f_cls
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(_func: AnyCallable) -> 'AnyClassMethod':
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(
|
||||
*, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
|
||||
) -> Callable[[AnyCallable], 'AnyClassMethod']:
|
||||
...
|
||||
|
||||
|
||||
def root_validator(
|
||||
_func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
|
||||
) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]:
|
||||
"""
|
||||
Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
|
||||
before or after standard model parsing/validation is performed.
|
||||
"""
|
||||
if _func:
|
||||
f_cls = _prepare_validator(_func, allow_reuse)
|
||||
setattr(
|
||||
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
|
||||
)
|
||||
return f_cls
|
||||
|
||||
def dec(f: AnyCallable) -> 'AnyClassMethod':
|
||||
f_cls = _prepare_validator(f, allow_reuse)
|
||||
setattr(
|
||||
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
|
||||
)
|
||||
return f_cls
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod':
|
||||
"""
|
||||
Avoid validators with duplicated names since without this, validators can be overwritten silently
|
||||
which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False.
|
||||
"""
|
||||
f_cls = function if isinstance(function, classmethod) else classmethod(function)
|
||||
if not in_ipython() and not allow_reuse:
|
||||
ref = (
|
||||
getattr(f_cls.__func__, '__module__', '<No __module__>')
|
||||
+ '.'
|
||||
+ getattr(f_cls.__func__, '__qualname__', f'<No __qualname__: id:{id(f_cls.__func__)}>')
|
||||
)
|
||||
if ref in _FUNCS:
|
||||
raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`')
|
||||
_FUNCS.add(ref)
|
||||
return f_cls
|
||||
|
||||
|
||||
class ValidatorGroup:
|
||||
def __init__(self, validators: 'ValidatorListDict') -> None:
|
||||
self.validators = validators
|
||||
self.used_validators = {'*'}
|
||||
|
||||
def get_validators(self, name: str) -> Optional[Dict[str, Validator]]:
|
||||
self.used_validators.add(name)
|
||||
validators = self.validators.get(name, [])
|
||||
if name != ROOT_KEY:
|
||||
validators += self.validators.get('*', [])
|
||||
if validators:
|
||||
return {getattr(v.func, '__name__', f'<No __name__: id:{id(v.func)}>'): v for v in validators}
|
||||
else:
|
||||
return None
|
||||
|
||||
def check_for_unused(self) -> None:
|
||||
unused_validators = set(
|
||||
chain.from_iterable(
|
||||
(
|
||||
getattr(v.func, '__name__', f'<No __name__: id:{id(v.func)}>')
|
||||
for v in self.validators[f]
|
||||
if v.check_fields
|
||||
)
|
||||
for f in (self.validators.keys() - self.used_validators)
|
||||
)
|
||||
)
|
||||
if unused_validators:
|
||||
fn = ', '.join(unused_validators)
|
||||
raise ConfigError(
|
||||
f"Validators defined with incorrect fields: {fn} " # noqa: Q000
|
||||
f"(use check_fields=False if you're inheriting from the model and intended this)"
|
||||
)
|
||||
|
||||
|
||||
def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
|
||||
validators: Dict[str, List[Validator]] = {}
|
||||
for var_name, value in namespace.items():
|
||||
validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None)
|
||||
if validator_config:
|
||||
fields, v = validator_config
|
||||
for field in fields:
|
||||
if field in validators:
|
||||
validators[field].append(v)
|
||||
else:
|
||||
validators[field] = [v]
|
||||
return validators
|
||||
|
||||
|
||||
def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]:
|
||||
from inspect import signature
|
||||
|
||||
pre_validators: List[AnyCallable] = []
|
||||
post_validators: List[Tuple[bool, AnyCallable]] = []
|
||||
for name, value in namespace.items():
|
||||
validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None)
|
||||
if validator_config:
|
||||
sig = signature(validator_config.func)
|
||||
args = list(sig.parameters.keys())
|
||||
if args[0] == 'self':
|
||||
raise ConfigError(
|
||||
f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, '
|
||||
f'should be: (cls, values).'
|
||||
)
|
||||
if len(args) != 2:
|
||||
raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).')
|
||||
# check function signature
|
||||
if validator_config.pre:
|
||||
pre_validators.append(validator_config.func)
|
||||
else:
|
||||
post_validators.append((validator_config.skip_on_failure, validator_config.func))
|
||||
return pre_validators, post_validators
|
||||
|
||||
|
||||
def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict':
|
||||
for field, field_validators in base_validators.items():
|
||||
if field not in validators:
|
||||
validators[field] = []
|
||||
validators[field] += field_validators
|
||||
return validators
|
||||
|
||||
|
||||
def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable':
|
||||
"""
|
||||
Make a generic function which calls a validator with the right arguments.
|
||||
|
||||
Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow,
|
||||
hence this laborious way of doing things.
|
||||
|
||||
It's done like this so validators don't all need **kwargs in their signature, eg. any combination of
|
||||
the arguments "values", "fields" and/or "config" are permitted.
|
||||
"""
|
||||
from inspect import signature
|
||||
|
||||
if not isinstance(validator, (partial, partialmethod)):
|
||||
# This should be the default case, so overhead is reduced
|
||||
sig = signature(validator)
|
||||
args = list(sig.parameters.keys())
|
||||
else:
|
||||
# Fix the generated argument lists of partial methods
|
||||
sig = signature(validator.func)
|
||||
args = [
|
||||
k
|
||||
for k in signature(validator.func).parameters.keys()
|
||||
if k not in validator.args | validator.keywords.keys()
|
||||
]
|
||||
|
||||
first_arg = args.pop(0)
|
||||
if first_arg == 'self':
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, '
|
||||
f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
elif first_arg == 'cls':
|
||||
# assume the second argument is value
|
||||
return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:])))
|
||||
else:
|
||||
# assume the first argument was value which has already been removed
|
||||
return wraps(validator)(_generic_validator_basic(validator, sig, set(args)))
|
||||
|
||||
|
||||
def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList':
|
||||
return [make_generic_validator(f) for f in v_funcs if f]
|
||||
|
||||
|
||||
all_kwargs = {'values', 'field', 'config'}
|
||||
|
||||
|
||||
def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
|
||||
# assume the first argument is value
|
||||
has_kwargs = False
|
||||
if 'kwargs' in args:
|
||||
has_kwargs = True
|
||||
args -= {'kwargs'}
|
||||
|
||||
if not args.issubset(all_kwargs):
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, should be: '
|
||||
f'(cls, value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
|
||||
if has_kwargs:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
|
||||
elif args == set():
|
||||
return lambda cls, v, values, field, config: validator(cls, v)
|
||||
elif args == {'values'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values)
|
||||
elif args == {'field'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, field=field)
|
||||
elif args == {'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, config=config)
|
||||
elif args == {'values', 'field'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field)
|
||||
elif args == {'values', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config)
|
||||
elif args == {'field', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config)
|
||||
else:
|
||||
# args == {'values', 'field', 'config'}
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
|
||||
|
||||
|
||||
def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
|
||||
has_kwargs = False
|
||||
if 'kwargs' in args:
|
||||
has_kwargs = True
|
||||
args -= {'kwargs'}
|
||||
|
||||
if not args.issubset(all_kwargs):
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, should be: '
|
||||
f'(value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
|
||||
if has_kwargs:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
|
||||
elif args == set():
|
||||
return lambda cls, v, values, field, config: validator(v)
|
||||
elif args == {'values'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values)
|
||||
elif args == {'field'}:
|
||||
return lambda cls, v, values, field, config: validator(v, field=field)
|
||||
elif args == {'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, config=config)
|
||||
elif args == {'values', 'field'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field)
|
||||
elif args == {'values', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, config=config)
|
||||
elif args == {'field', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, field=field, config=config)
|
||||
else:
|
||||
# args == {'values', 'field', 'config'}
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
|
||||
|
||||
|
||||
def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']:
|
||||
all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) # type: ignore[arg-type,var-annotated]
|
||||
return {
|
||||
k: v
|
||||
for k, v in all_attributes.items()
|
||||
if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY)
|
||||
}
|
||||
494
venv/lib/python3.11/site-packages/pydantic/v1/color.py
Normal file
494
venv/lib/python3.11/site-packages/pydantic/v1/color.py
Normal file
@@ -0,0 +1,494 @@
|
||||
"""
|
||||
Color definitions are used as per CSS3 specification:
|
||||
http://www.w3.org/TR/css3-color/#svg-color
|
||||
|
||||
A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`.
|
||||
|
||||
In these cases the LAST color when sorted alphabetically takes preferences,
|
||||
eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua".
|
||||
"""
|
||||
import math
|
||||
import re
|
||||
from colorsys import hls_to_rgb, rgb_to_hls
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
from pydantic.v1.errors import ColorError
|
||||
from pydantic.v1.utils import Representation, almost_equal_floats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import CallableGenerator, ReprArgs
|
||||
|
||||
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
|
||||
ColorType = Union[ColorTuple, str]
|
||||
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
|
||||
|
||||
|
||||
class RGBA:
|
||||
"""
|
||||
Internal use only as a representation of a color.
|
||||
"""
|
||||
|
||||
__slots__ = 'r', 'g', 'b', 'alpha', '_tuple'
|
||||
|
||||
def __init__(self, r: float, g: float, b: float, alpha: Optional[float]):
|
||||
self.r = r
|
||||
self.g = g
|
||||
self.b = b
|
||||
self.alpha = alpha
|
||||
|
||||
self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
|
||||
|
||||
def __getitem__(self, item: Any) -> Any:
|
||||
return self._tuple[item]
|
||||
|
||||
|
||||
# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached
|
||||
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
|
||||
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
|
||||
_r_255 = r'(\d{1,3}(?:\.\d+)?)'
|
||||
_r_comma = r'\s*,\s*'
|
||||
r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*'
|
||||
_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
|
||||
r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*'
|
||||
_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
|
||||
_r_sl = r'(\d{1,3}(?:\.\d+)?)%'
|
||||
r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*'
|
||||
r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*'
|
||||
|
||||
# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
|
||||
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
|
||||
rads = 2 * math.pi
|
||||
|
||||
|
||||
class Color(Representation):
|
||||
__slots__ = '_original', '_rgba'
|
||||
|
||||
def __init__(self, value: ColorType) -> None:
|
||||
self._rgba: RGBA
|
||||
self._original: ColorType
|
||||
if isinstance(value, (tuple, list)):
|
||||
self._rgba = parse_tuple(value)
|
||||
elif isinstance(value, str):
|
||||
self._rgba = parse_str(value)
|
||||
elif isinstance(value, Color):
|
||||
self._rgba = value._rgba
|
||||
value = value._original
|
||||
else:
|
||||
raise ColorError(reason='value must be a tuple, list or string')
|
||||
|
||||
# if we've got here value must be a valid color
|
||||
self._original = value
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='color')
|
||||
|
||||
def original(self) -> ColorType:
|
||||
"""
|
||||
Original value passed to Color
|
||||
"""
|
||||
return self._original
|
||||
|
||||
def as_named(self, *, fallback: bool = False) -> str:
|
||||
if self._rgba.alpha is None:
|
||||
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
|
||||
try:
|
||||
return COLORS_BY_VALUE[rgb]
|
||||
except KeyError as e:
|
||||
if fallback:
|
||||
return self.as_hex()
|
||||
else:
|
||||
raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e
|
||||
else:
|
||||
return self.as_hex()
|
||||
|
||||
def as_hex(self) -> str:
|
||||
"""
|
||||
Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether the string
|
||||
a "short" representation of the color is possible and whether there's an alpha channel.
|
||||
"""
|
||||
values = [float_to_255(c) for c in self._rgba[:3]]
|
||||
if self._rgba.alpha is not None:
|
||||
values.append(float_to_255(self._rgba.alpha))
|
||||
|
||||
as_hex = ''.join(f'{v:02x}' for v in values)
|
||||
if all(c in repeat_colors for c in values):
|
||||
as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2))
|
||||
return '#' + as_hex
|
||||
|
||||
def as_rgb(self) -> str:
|
||||
"""
|
||||
Color as an rgb(<r>, <g>, <b>) or rgba(<r>, <g>, <b>, <a>) string.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
|
||||
else:
|
||||
return (
|
||||
f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, '
|
||||
f'{round(self._alpha_float(), 2)})'
|
||||
)
|
||||
|
||||
def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
|
||||
"""
|
||||
Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is
|
||||
in the range 0 to 1.
|
||||
|
||||
:param alpha: whether to include the alpha channel, options are
|
||||
None - (default) include alpha only if it's set (e.g. not None)
|
||||
True - always include alpha,
|
||||
False - always omit alpha,
|
||||
"""
|
||||
r, g, b = (float_to_255(c) for c in self._rgba[:3])
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return r, g, b
|
||||
else:
|
||||
return r, g, b, self._alpha_float()
|
||||
elif alpha:
|
||||
return r, g, b, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return r, g, b
|
||||
|
||||
def as_hsl(self) -> str:
|
||||
"""
|
||||
Color as an hsl(<h>, <s>, <l>) or hsl(<h>, <s>, <l>, <a>) string.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
|
||||
else:
|
||||
h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'
|
||||
|
||||
def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
|
||||
"""
|
||||
Color as an HSL or HSLA tuple, e.g. hue, saturation, lightness and optionally alpha; all elements are in
|
||||
the range 0 to 1.
|
||||
|
||||
NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys.
|
||||
|
||||
:param alpha: whether to include the alpha channel, options are
|
||||
None - (default) include alpha only if it's set (e.g. not None)
|
||||
True - always include alpha,
|
||||
False - always omit alpha,
|
||||
"""
|
||||
h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b)
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return h, s, l
|
||||
else:
|
||||
return h, s, l, self._alpha_float()
|
||||
if alpha:
|
||||
return h, s, l, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return h, s, l
|
||||
|
||||
def _alpha_float(self) -> float:
|
||||
return 1 if self._rgba.alpha is None else self._rgba.alpha
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.as_named(fallback=True)
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] # type: ignore
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.as_rgb_tuple())
|
||||
|
||||
|
||||
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
|
||||
"""
|
||||
Parse a tuple or list as a color.
|
||||
"""
|
||||
if len(value) == 3:
|
||||
r, g, b = (parse_color_value(v) for v in value)
|
||||
return RGBA(r, g, b, None)
|
||||
elif len(value) == 4:
|
||||
r, g, b = (parse_color_value(v) for v in value[:3])
|
||||
return RGBA(r, g, b, parse_float_alpha(value[3]))
|
||||
else:
|
||||
raise ColorError(reason='tuples must have length 3 or 4')
|
||||
|
||||
|
||||
def parse_str(value: str) -> RGBA:
|
||||
"""
|
||||
Parse a string to an RGBA tuple, trying the following formats (in this order):
|
||||
* named color, see COLORS_BY_NAME below
|
||||
* hex short eg. `<prefix>fff` (prefix can be `#`, `0x` or nothing)
|
||||
* hex long eg. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing)
|
||||
* `rgb(<r>, <g>, <b>) `
|
||||
* `rgba(<r>, <g>, <b>, <a>)`
|
||||
"""
|
||||
value_lower = value.lower()
|
||||
try:
|
||||
r, g, b = COLORS_BY_NAME[value_lower]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return ints_to_rgba(r, g, b, None)
|
||||
|
||||
m = re.fullmatch(r_hex_short, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v * 2, 16) for v in rgb)
|
||||
if a:
|
||||
alpha: Optional[float] = int(a * 2, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_hex_long, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v, 16) for v in rgb)
|
||||
if a:
|
||||
alpha = int(a, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_rgb, value_lower)
|
||||
if m:
|
||||
return ints_to_rgba(*m.groups(), None) # type: ignore
|
||||
|
||||
m = re.fullmatch(r_rgba, value_lower)
|
||||
if m:
|
||||
return ints_to_rgba(*m.groups()) # type: ignore
|
||||
|
||||
m = re.fullmatch(r_hsl, value_lower)
|
||||
if m:
|
||||
h, h_units, s, l_ = m.groups()
|
||||
return parse_hsl(h, h_units, s, l_)
|
||||
|
||||
m = re.fullmatch(r_hsla, value_lower)
|
||||
if m:
|
||||
h, h_units, s, l_, a = m.groups()
|
||||
return parse_hsl(h, h_units, s, l_, parse_float_alpha(a))
|
||||
|
||||
raise ColorError(reason='string not recognised as a valid color')
|
||||
|
||||
|
||||
def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float]) -> RGBA:
|
||||
return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))
|
||||
|
||||
|
||||
def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
|
||||
"""
|
||||
Parse a value checking it's a valid int in the range 0 to max_val and divide by max_val to give a number
|
||||
in the range 0 to 1
|
||||
"""
|
||||
try:
|
||||
color = float(value)
|
||||
except ValueError:
|
||||
raise ColorError(reason='color values must be a valid number')
|
||||
if 0 <= color <= max_val:
|
||||
return color / max_val
|
||||
else:
|
||||
raise ColorError(reason=f'color values must be in the range 0 to {max_val}')
|
||||
|
||||
|
||||
def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
|
||||
"""
|
||||
Parse a value checking it's a valid float in the range 0 to 1
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
if isinstance(value, str) and value.endswith('%'):
|
||||
alpha = float(value[:-1]) / 100
|
||||
else:
|
||||
alpha = float(value)
|
||||
except ValueError:
|
||||
raise ColorError(reason='alpha values must be a valid float')
|
||||
|
||||
if almost_equal_floats(alpha, 1):
|
||||
return None
|
||||
elif 0 <= alpha <= 1:
|
||||
return alpha
|
||||
else:
|
||||
raise ColorError(reason='alpha values must be in the range 0 to 1')
|
||||
|
||||
|
||||
def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
|
||||
"""
|
||||
Parse raw hue, saturation, lightness and alpha values and convert to RGBA.
|
||||
"""
|
||||
s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100)
|
||||
|
||||
h_value = float(h)
|
||||
if h_units in {None, 'deg'}:
|
||||
h_value = h_value % 360 / 360
|
||||
elif h_units == 'rad':
|
||||
h_value = h_value % rads / rads
|
||||
else:
|
||||
# turns
|
||||
h_value = h_value % 1
|
||||
|
||||
r, g, b = hls_to_rgb(h_value, l_value, s_value)
|
||||
return RGBA(r, g, b, alpha)
|
||||
|
||||
|
||||
def float_to_255(c: float) -> int:
|
||||
return int(round(c * 255))
|
||||
|
||||
|
||||
COLORS_BY_NAME = {
|
||||
'aliceblue': (240, 248, 255),
|
||||
'antiquewhite': (250, 235, 215),
|
||||
'aqua': (0, 255, 255),
|
||||
'aquamarine': (127, 255, 212),
|
||||
'azure': (240, 255, 255),
|
||||
'beige': (245, 245, 220),
|
||||
'bisque': (255, 228, 196),
|
||||
'black': (0, 0, 0),
|
||||
'blanchedalmond': (255, 235, 205),
|
||||
'blue': (0, 0, 255),
|
||||
'blueviolet': (138, 43, 226),
|
||||
'brown': (165, 42, 42),
|
||||
'burlywood': (222, 184, 135),
|
||||
'cadetblue': (95, 158, 160),
|
||||
'chartreuse': (127, 255, 0),
|
||||
'chocolate': (210, 105, 30),
|
||||
'coral': (255, 127, 80),
|
||||
'cornflowerblue': (100, 149, 237),
|
||||
'cornsilk': (255, 248, 220),
|
||||
'crimson': (220, 20, 60),
|
||||
'cyan': (0, 255, 255),
|
||||
'darkblue': (0, 0, 139),
|
||||
'darkcyan': (0, 139, 139),
|
||||
'darkgoldenrod': (184, 134, 11),
|
||||
'darkgray': (169, 169, 169),
|
||||
'darkgreen': (0, 100, 0),
|
||||
'darkgrey': (169, 169, 169),
|
||||
'darkkhaki': (189, 183, 107),
|
||||
'darkmagenta': (139, 0, 139),
|
||||
'darkolivegreen': (85, 107, 47),
|
||||
'darkorange': (255, 140, 0),
|
||||
'darkorchid': (153, 50, 204),
|
||||
'darkred': (139, 0, 0),
|
||||
'darksalmon': (233, 150, 122),
|
||||
'darkseagreen': (143, 188, 143),
|
||||
'darkslateblue': (72, 61, 139),
|
||||
'darkslategray': (47, 79, 79),
|
||||
'darkslategrey': (47, 79, 79),
|
||||
'darkturquoise': (0, 206, 209),
|
||||
'darkviolet': (148, 0, 211),
|
||||
'deeppink': (255, 20, 147),
|
||||
'deepskyblue': (0, 191, 255),
|
||||
'dimgray': (105, 105, 105),
|
||||
'dimgrey': (105, 105, 105),
|
||||
'dodgerblue': (30, 144, 255),
|
||||
'firebrick': (178, 34, 34),
|
||||
'floralwhite': (255, 250, 240),
|
||||
'forestgreen': (34, 139, 34),
|
||||
'fuchsia': (255, 0, 255),
|
||||
'gainsboro': (220, 220, 220),
|
||||
'ghostwhite': (248, 248, 255),
|
||||
'gold': (255, 215, 0),
|
||||
'goldenrod': (218, 165, 32),
|
||||
'gray': (128, 128, 128),
|
||||
'green': (0, 128, 0),
|
||||
'greenyellow': (173, 255, 47),
|
||||
'grey': (128, 128, 128),
|
||||
'honeydew': (240, 255, 240),
|
||||
'hotpink': (255, 105, 180),
|
||||
'indianred': (205, 92, 92),
|
||||
'indigo': (75, 0, 130),
|
||||
'ivory': (255, 255, 240),
|
||||
'khaki': (240, 230, 140),
|
||||
'lavender': (230, 230, 250),
|
||||
'lavenderblush': (255, 240, 245),
|
||||
'lawngreen': (124, 252, 0),
|
||||
'lemonchiffon': (255, 250, 205),
|
||||
'lightblue': (173, 216, 230),
|
||||
'lightcoral': (240, 128, 128),
|
||||
'lightcyan': (224, 255, 255),
|
||||
'lightgoldenrodyellow': (250, 250, 210),
|
||||
'lightgray': (211, 211, 211),
|
||||
'lightgreen': (144, 238, 144),
|
||||
'lightgrey': (211, 211, 211),
|
||||
'lightpink': (255, 182, 193),
|
||||
'lightsalmon': (255, 160, 122),
|
||||
'lightseagreen': (32, 178, 170),
|
||||
'lightskyblue': (135, 206, 250),
|
||||
'lightslategray': (119, 136, 153),
|
||||
'lightslategrey': (119, 136, 153),
|
||||
'lightsteelblue': (176, 196, 222),
|
||||
'lightyellow': (255, 255, 224),
|
||||
'lime': (0, 255, 0),
|
||||
'limegreen': (50, 205, 50),
|
||||
'linen': (250, 240, 230),
|
||||
'magenta': (255, 0, 255),
|
||||
'maroon': (128, 0, 0),
|
||||
'mediumaquamarine': (102, 205, 170),
|
||||
'mediumblue': (0, 0, 205),
|
||||
'mediumorchid': (186, 85, 211),
|
||||
'mediumpurple': (147, 112, 219),
|
||||
'mediumseagreen': (60, 179, 113),
|
||||
'mediumslateblue': (123, 104, 238),
|
||||
'mediumspringgreen': (0, 250, 154),
|
||||
'mediumturquoise': (72, 209, 204),
|
||||
'mediumvioletred': (199, 21, 133),
|
||||
'midnightblue': (25, 25, 112),
|
||||
'mintcream': (245, 255, 250),
|
||||
'mistyrose': (255, 228, 225),
|
||||
'moccasin': (255, 228, 181),
|
||||
'navajowhite': (255, 222, 173),
|
||||
'navy': (0, 0, 128),
|
||||
'oldlace': (253, 245, 230),
|
||||
'olive': (128, 128, 0),
|
||||
'olivedrab': (107, 142, 35),
|
||||
'orange': (255, 165, 0),
|
||||
'orangered': (255, 69, 0),
|
||||
'orchid': (218, 112, 214),
|
||||
'palegoldenrod': (238, 232, 170),
|
||||
'palegreen': (152, 251, 152),
|
||||
'paleturquoise': (175, 238, 238),
|
||||
'palevioletred': (219, 112, 147),
|
||||
'papayawhip': (255, 239, 213),
|
||||
'peachpuff': (255, 218, 185),
|
||||
'peru': (205, 133, 63),
|
||||
'pink': (255, 192, 203),
|
||||
'plum': (221, 160, 221),
|
||||
'powderblue': (176, 224, 230),
|
||||
'purple': (128, 0, 128),
|
||||
'red': (255, 0, 0),
|
||||
'rosybrown': (188, 143, 143),
|
||||
'royalblue': (65, 105, 225),
|
||||
'saddlebrown': (139, 69, 19),
|
||||
'salmon': (250, 128, 114),
|
||||
'sandybrown': (244, 164, 96),
|
||||
'seagreen': (46, 139, 87),
|
||||
'seashell': (255, 245, 238),
|
||||
'sienna': (160, 82, 45),
|
||||
'silver': (192, 192, 192),
|
||||
'skyblue': (135, 206, 235),
|
||||
'slateblue': (106, 90, 205),
|
||||
'slategray': (112, 128, 144),
|
||||
'slategrey': (112, 128, 144),
|
||||
'snow': (255, 250, 250),
|
||||
'springgreen': (0, 255, 127),
|
||||
'steelblue': (70, 130, 180),
|
||||
'tan': (210, 180, 140),
|
||||
'teal': (0, 128, 128),
|
||||
'thistle': (216, 191, 216),
|
||||
'tomato': (255, 99, 71),
|
||||
'turquoise': (64, 224, 208),
|
||||
'violet': (238, 130, 238),
|
||||
'wheat': (245, 222, 179),
|
||||
'white': (255, 255, 255),
|
||||
'whitesmoke': (245, 245, 245),
|
||||
'yellow': (255, 255, 0),
|
||||
'yellowgreen': (154, 205, 50),
|
||||
}
|
||||
|
||||
COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()}
|
||||
191
venv/lib/python3.11/site-packages/pydantic/v1/config.py
Normal file
191
venv/lib/python3.11/site-packages/pydantic/v1/config.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union
|
||||
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
from pydantic.v1.typing import AnyArgTCallable, AnyCallable
|
||||
from pydantic.v1.utils import GetterDict
|
||||
from pydantic.v1.version import compiled
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import overload
|
||||
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
|
||||
ConfigType = Type['BaseConfig']
|
||||
|
||||
class SchemaExtraCallable(Protocol):
|
||||
@overload
|
||||
def __call__(self, schema: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
pass
|
||||
|
||||
else:
|
||||
SchemaExtraCallable = Callable[..., None]
|
||||
|
||||
__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config'
|
||||
|
||||
|
||||
class Extra(str, Enum):
|
||||
allow = 'allow'
|
||||
ignore = 'ignore'
|
||||
forbid = 'forbid'
|
||||
|
||||
|
||||
# https://github.com/cython/cython/issues/4003
|
||||
# Fixed in Cython 3 and Pydantic v1 won't support Cython 3.
|
||||
# Pydantic v2 doesn't depend on Cython at all.
|
||||
if not compiled:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
class ConfigDict(TypedDict, total=False):
|
||||
title: Optional[str]
|
||||
anystr_lower: bool
|
||||
anystr_strip_whitespace: bool
|
||||
min_anystr_length: int
|
||||
max_anystr_length: Optional[int]
|
||||
validate_all: bool
|
||||
extra: Extra
|
||||
allow_mutation: bool
|
||||
frozen: bool
|
||||
allow_population_by_field_name: bool
|
||||
use_enum_values: bool
|
||||
fields: Dict[str, Union[str, Dict[str, str]]]
|
||||
validate_assignment: bool
|
||||
error_msg_templates: Dict[str, str]
|
||||
arbitrary_types_allowed: bool
|
||||
orm_mode: bool
|
||||
getter_dict: Type[GetterDict]
|
||||
alias_generator: Optional[Callable[[str], str]]
|
||||
keep_untouched: Tuple[type, ...]
|
||||
schema_extra: Union[Dict[str, object], 'SchemaExtraCallable']
|
||||
json_loads: Callable[[str], object]
|
||||
json_dumps: AnyArgTCallable[str]
|
||||
json_encoders: Dict[Type[object], AnyCallable]
|
||||
underscore_attrs_are_private: bool
|
||||
allow_inf_nan: bool
|
||||
copy_on_model_validation: Literal['none', 'deep', 'shallow']
|
||||
# whether dataclass `__post_init__` should be run after validation
|
||||
post_init_call: Literal['before_validation', 'after_validation']
|
||||
|
||||
else:
|
||||
ConfigDict = dict # type: ignore
|
||||
|
||||
|
||||
class BaseConfig:
|
||||
title: Optional[str] = None
|
||||
anystr_lower: bool = False
|
||||
anystr_upper: bool = False
|
||||
anystr_strip_whitespace: bool = False
|
||||
min_anystr_length: int = 0
|
||||
max_anystr_length: Optional[int] = None
|
||||
validate_all: bool = False
|
||||
extra: Extra = Extra.ignore
|
||||
allow_mutation: bool = True
|
||||
frozen: bool = False
|
||||
allow_population_by_field_name: bool = False
|
||||
use_enum_values: bool = False
|
||||
fields: Dict[str, Union[str, Dict[str, str]]] = {}
|
||||
validate_assignment: bool = False
|
||||
error_msg_templates: Dict[str, str] = {}
|
||||
arbitrary_types_allowed: bool = False
|
||||
orm_mode: bool = False
|
||||
getter_dict: Type[GetterDict] = GetterDict
|
||||
alias_generator: Optional[Callable[[str], str]] = None
|
||||
keep_untouched: Tuple[type, ...] = ()
|
||||
schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] = {}
|
||||
json_loads: Callable[[str], Any] = json.loads
|
||||
json_dumps: Callable[..., str] = json.dumps
|
||||
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {}
|
||||
underscore_attrs_are_private: bool = False
|
||||
allow_inf_nan: bool = True
|
||||
|
||||
# whether inherited models as fields should be reconstructed as base model,
|
||||
# and whether such a copy should be shallow or deep
|
||||
copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow'
|
||||
|
||||
# whether `Union` should check all allowed types before even trying to coerce
|
||||
smart_union: bool = False
|
||||
# whether dataclass `__post_init__` should be run before or after validation
|
||||
post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation'
|
||||
|
||||
@classmethod
|
||||
def get_field_info(cls, name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get properties of FieldInfo from the `fields` property of the config class.
|
||||
"""
|
||||
|
||||
fields_value = cls.fields.get(name)
|
||||
|
||||
if isinstance(fields_value, str):
|
||||
field_info: Dict[str, Any] = {'alias': fields_value}
|
||||
elif isinstance(fields_value, dict):
|
||||
field_info = fields_value
|
||||
else:
|
||||
field_info = {}
|
||||
|
||||
if 'alias' in field_info:
|
||||
field_info.setdefault('alias_priority', 2)
|
||||
|
||||
if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator:
|
||||
alias = cls.alias_generator(name)
|
||||
if not isinstance(alias, str):
|
||||
raise TypeError(f'Config.alias_generator must return str, not {alias.__class__}')
|
||||
field_info.update(alias=alias, alias_priority=1)
|
||||
return field_info
|
||||
|
||||
@classmethod
|
||||
def prepare_field(cls, field: 'ModelField') -> None:
|
||||
"""
|
||||
Optional hook to check or modify fields during model creation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_config(config: Union[ConfigDict, Type[object], None]) -> Type[BaseConfig]:
|
||||
if config is None:
|
||||
return BaseConfig
|
||||
|
||||
else:
|
||||
config_dict = (
|
||||
config
|
||||
if isinstance(config, dict)
|
||||
else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
|
||||
)
|
||||
|
||||
class Config(BaseConfig):
|
||||
...
|
||||
|
||||
for k, v in config_dict.items():
|
||||
setattr(Config, k, v)
|
||||
return Config
|
||||
|
||||
|
||||
def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType':
|
||||
if not self_config:
|
||||
base_classes: Tuple['ConfigType', ...] = (parent_config,)
|
||||
elif self_config == parent_config:
|
||||
base_classes = (self_config,)
|
||||
else:
|
||||
base_classes = self_config, parent_config
|
||||
|
||||
namespace['json_encoders'] = {
|
||||
**getattr(parent_config, 'json_encoders', {}),
|
||||
**getattr(self_config, 'json_encoders', {}),
|
||||
**namespace.get('json_encoders', {}),
|
||||
}
|
||||
|
||||
return type('Config', base_classes, namespace)
|
||||
|
||||
|
||||
def prepare_config(config: Type[BaseConfig], cls_name: str) -> None:
|
||||
if not isinstance(config.extra, Extra):
|
||||
try:
|
||||
config.extra = Extra(config.extra)
|
||||
except ValueError:
|
||||
raise ValueError(f'"{cls_name}": {config.extra} is not a valid value for "extra"')
|
||||
500
venv/lib/python3.11/site-packages/pydantic/v1/dataclasses.py
Normal file
500
venv/lib/python3.11/site-packages/pydantic/v1/dataclasses.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
The main purpose is to enhance stdlib dataclasses by adding validation
|
||||
A pydantic dataclass can be generated from scratch or from a stdlib one.
|
||||
|
||||
Behind the scene, a pydantic dataclass is just like a regular one on which we attach
|
||||
a `BaseModel` and magic methods to trigger the validation of the data.
|
||||
`__init__` and `__post_init__` are hence overridden and have extra logic to be
|
||||
able to validate input data.
|
||||
|
||||
When a pydantic dataclass is generated from scratch, it's just a plain dataclass
|
||||
with validation triggered at initialization
|
||||
|
||||
The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g.
|
||||
|
||||
```py
|
||||
@dataclasses.dataclass
|
||||
class M:
|
||||
x: int
|
||||
|
||||
ValidatedM = pydantic.dataclasses.dataclass(M)
|
||||
```
|
||||
|
||||
We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
|
||||
|
||||
```py
|
||||
assert isinstance(ValidatedM(x=1), M)
|
||||
assert ValidatedM(x=1) == M(x=1)
|
||||
```
|
||||
|
||||
This means we **don't want to create a new dataclass that inherits from it**
|
||||
The trick is to create a wrapper around `M` that will act as a proxy to trigger
|
||||
validation without altering default `M` behaviour.
|
||||
"""
|
||||
import copy
|
||||
import dataclasses
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
from functools import cached_property
|
||||
except ImportError:
|
||||
# cached_property available only for python3.8+
|
||||
pass
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
from pydantic.v1.class_validators import gather_all_validators
|
||||
from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
from pydantic.v1.errors import DataclassTypeError
|
||||
from pydantic.v1.fields import Field, FieldInfo, Required, Undefined
|
||||
from pydantic.v1.main import create_model, validate_model
|
||||
from pydantic.v1.utils import ClassAttribute
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable
|
||||
|
||||
DataclassT = TypeVar('DataclassT', bound='Dataclass')
|
||||
|
||||
DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']
|
||||
|
||||
class Dataclass:
|
||||
# stdlib attributes
|
||||
__dataclass_fields__: ClassVar[Dict[str, Any]]
|
||||
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
|
||||
__post_init__: ClassVar[Callable[..., None]]
|
||||
|
||||
# Added by pydantic
|
||||
__pydantic_run_validation__: ClassVar[bool]
|
||||
__post_init_post_parse__: ClassVar[Callable[..., None]]
|
||||
__pydantic_initialised__: ClassVar[bool]
|
||||
__pydantic_model__: ClassVar[Type[BaseModel]]
|
||||
__pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
|
||||
__pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
'dataclass',
|
||||
'set_validation',
|
||||
'create_pydantic_model_from_dataclass',
|
||||
'is_builtin_dataclass',
|
||||
'make_dataclass_validator',
|
||||
]
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
kw_only: bool = ...,
|
||||
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: Type[_T],
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
kw_only: bool = ...,
|
||||
) -> 'DataclassClassOrWrapper':
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: Type[_T],
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
) -> 'DataclassClassOrWrapper':
|
||||
...
|
||||
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
def dataclass(
|
||||
_cls: Optional[Type[_T]] = None,
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
use_proxy: Optional[bool] = None,
|
||||
kw_only: bool = False,
|
||||
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
|
||||
"""
|
||||
Like the python standard lib dataclasses but with type validation.
|
||||
The result is either a pydantic dataclass that will validate input data
|
||||
or a wrapper that will trigger validation around a stdlib dataclass
|
||||
to avoid modifying it directly
|
||||
"""
|
||||
the_config = get_config(config)
|
||||
|
||||
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
|
||||
should_use_proxy = (
|
||||
use_proxy
|
||||
if use_proxy is not None
|
||||
else (
|
||||
is_builtin_dataclass(cls)
|
||||
and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0])))
|
||||
)
|
||||
)
|
||||
if should_use_proxy:
|
||||
dc_cls_doc = ''
|
||||
dc_cls = DataclassProxy(cls)
|
||||
default_validate_on_init = False
|
||||
else:
|
||||
dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
|
||||
if sys.version_info >= (3, 10):
|
||||
dc_cls = dataclasses.dataclass(
|
||||
cls,
|
||||
init=init,
|
||||
repr=repr,
|
||||
eq=eq,
|
||||
order=order,
|
||||
unsafe_hash=unsafe_hash,
|
||||
frozen=frozen,
|
||||
kw_only=kw_only,
|
||||
)
|
||||
else:
|
||||
dc_cls = dataclasses.dataclass( # type: ignore
|
||||
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
|
||||
)
|
||||
default_validate_on_init = True
|
||||
|
||||
should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
|
||||
_add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
|
||||
dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
|
||||
return dc_cls
|
||||
|
||||
if _cls is None:
|
||||
return wrap
|
||||
|
||||
return wrap(_cls)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
|
||||
original_run_validation = cls.__pydantic_run_validation__
|
||||
try:
|
||||
cls.__pydantic_run_validation__ = value
|
||||
yield cls
|
||||
finally:
|
||||
cls.__pydantic_run_validation__ = original_run_validation
|
||||
|
||||
|
||||
class DataclassProxy:
|
||||
__slots__ = '__dataclass__'
|
||||
|
||||
def __init__(self, dc_cls: Type['Dataclass']) -> None:
|
||||
object.__setattr__(self, '__dataclass__', dc_cls)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
with set_validation(self.__dataclass__, True):
|
||||
return self.__dataclass__(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.__dataclass__, name)
|
||||
|
||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||
return setattr(self.__dataclass__, __name, __value)
|
||||
|
||||
def __instancecheck__(self, instance: Any) -> bool:
|
||||
return isinstance(instance, self.__dataclass__)
|
||||
|
||||
def __copy__(self) -> 'DataclassProxy':
|
||||
return DataclassProxy(copy.copy(self.__dataclass__))
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> 'DataclassProxy':
|
||||
return DataclassProxy(copy.deepcopy(self.__dataclass__, memo))
|
||||
|
||||
|
||||
def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity)
|
||||
dc_cls: Type['Dataclass'],
|
||||
config: Type[BaseConfig],
|
||||
validate_on_init: bool,
|
||||
dc_cls_doc: str,
|
||||
) -> None:
|
||||
"""
|
||||
We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
|
||||
it won't even exist (code is generated on the fly by `dataclasses`)
|
||||
By default, we run validation after `__init__` or `__post_init__` if defined
|
||||
"""
|
||||
init = dc_cls.__init__
|
||||
|
||||
@wraps(init)
|
||||
def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
if config.extra == Extra.ignore:
|
||||
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
|
||||
|
||||
elif config.extra == Extra.allow:
|
||||
for k, v in kwargs.items():
|
||||
self.__dict__.setdefault(k, v)
|
||||
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
|
||||
|
||||
else:
|
||||
init(self, *args, **kwargs)
|
||||
|
||||
if hasattr(dc_cls, '__post_init__'):
|
||||
try:
|
||||
post_init = dc_cls.__post_init__.__wrapped__ # type: ignore[attr-defined]
|
||||
except AttributeError:
|
||||
post_init = dc_cls.__post_init__
|
||||
|
||||
@wraps(post_init)
|
||||
def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
if config.post_init_call == 'before_validation':
|
||||
post_init(self, *args, **kwargs)
|
||||
|
||||
if self.__class__.__pydantic_run_validation__:
|
||||
self.__pydantic_validate_values__()
|
||||
if hasattr(self, '__post_init_post_parse__'):
|
||||
self.__post_init_post_parse__(*args, **kwargs)
|
||||
|
||||
if config.post_init_call == 'after_validation':
|
||||
post_init(self, *args, **kwargs)
|
||||
|
||||
setattr(dc_cls, '__init__', handle_extra_init)
|
||||
setattr(dc_cls, '__post_init__', new_post_init)
|
||||
|
||||
else:
|
||||
|
||||
@wraps(init)
|
||||
def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
handle_extra_init(self, *args, **kwargs)
|
||||
|
||||
if self.__class__.__pydantic_run_validation__:
|
||||
self.__pydantic_validate_values__()
|
||||
|
||||
if hasattr(self, '__post_init_post_parse__'):
|
||||
# We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
|
||||
# public method `dataclasses.fields`
|
||||
|
||||
# get all initvars and their default values
|
||||
initvars_and_values: Dict[str, Any] = {}
|
||||
for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
|
||||
if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined]
|
||||
try:
|
||||
# set arg value by default
|
||||
initvars_and_values[f.name] = args[i]
|
||||
except IndexError:
|
||||
initvars_and_values[f.name] = kwargs.get(f.name, f.default)
|
||||
|
||||
self.__post_init_post_parse__(**initvars_and_values)
|
||||
|
||||
setattr(dc_cls, '__init__', new_init)
|
||||
|
||||
setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
|
||||
setattr(dc_cls, '__pydantic_initialised__', False)
|
||||
setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
|
||||
setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
|
||||
setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
|
||||
setattr(dc_cls, '__get_validators__', classmethod(_get_validators))
|
||||
|
||||
if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
|
||||
setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)
|
||||
|
||||
|
||||
def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
|
||||
yield cls.__validate__
|
||||
|
||||
|
||||
def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
|
||||
with set_validation(cls, True):
|
||||
if isinstance(v, cls):
|
||||
v.__pydantic_validate_values__()
|
||||
return v
|
||||
elif isinstance(v, (list, tuple)):
|
||||
return cls(*v)
|
||||
elif isinstance(v, dict):
|
||||
return cls(**v)
|
||||
else:
|
||||
raise DataclassTypeError(class_name=cls.__name__)
|
||||
|
||||
|
||||
def create_pydantic_model_from_dataclass(
|
||||
dc_cls: Type['Dataclass'],
|
||||
config: Type[Any] = BaseConfig,
|
||||
dc_cls_doc: Optional[str] = None,
|
||||
) -> Type['BaseModel']:
|
||||
field_definitions: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(dc_cls):
|
||||
default: Any = Undefined
|
||||
default_factory: Optional['NoArgAnyCallable'] = None
|
||||
field_info: FieldInfo
|
||||
|
||||
if field.default is not dataclasses.MISSING:
|
||||
default = field.default
|
||||
elif field.default_factory is not dataclasses.MISSING:
|
||||
default_factory = field.default_factory
|
||||
else:
|
||||
default = Required
|
||||
|
||||
if isinstance(default, FieldInfo):
|
||||
field_info = default
|
||||
dc_cls.__pydantic_has_field_info_default__ = True
|
||||
else:
|
||||
field_info = Field(default=default, default_factory=default_factory, **field.metadata)
|
||||
|
||||
field_definitions[field.name] = (field.type, field_info)
|
||||
|
||||
validators = gather_all_validators(dc_cls)
|
||||
model: Type['BaseModel'] = create_model(
|
||||
dc_cls.__name__,
|
||||
__config__=config,
|
||||
__module__=dc_cls.__module__,
|
||||
__validators__=validators,
|
||||
__cls_kwargs__={'__resolve_forward_refs__': False},
|
||||
**field_definitions,
|
||||
)
|
||||
model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
|
||||
return model
|
||||
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
|
||||
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
|
||||
return isinstance(getattr(type(obj), k, None), cached_property)
|
||||
|
||||
else:
|
||||
|
||||
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _dataclass_validate_values(self: 'Dataclass') -> None:
|
||||
# validation errors can occur if this function is called twice on an already initialised dataclass.
|
||||
# for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
|
||||
if getattr(self, '__pydantic_initialised__'):
|
||||
return
|
||||
if getattr(self, '__pydantic_has_field_info_default__', False):
|
||||
# We need to remove `FieldInfo` values since they are not valid as input
|
||||
# It's ok to do that because they are obviously the default values!
|
||||
input_data = {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
|
||||
}
|
||||
else:
|
||||
input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
|
||||
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
|
||||
if validation_error:
|
||||
raise validation_error
|
||||
self.__dict__.update(d)
|
||||
object.__setattr__(self, '__pydantic_initialised__', True)
|
||||
|
||||
|
||||
def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
|
||||
if self.__pydantic_initialised__:
|
||||
d = dict(self.__dict__)
|
||||
d.pop(name, None)
|
||||
known_field = self.__pydantic_model__.__fields__.get(name, None)
|
||||
if known_field:
|
||||
value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
|
||||
if error_:
|
||||
raise ValidationError([error_], self.__class__)
|
||||
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
|
||||
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
|
||||
"""
|
||||
Whether a class is a stdlib dataclass
|
||||
(useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass)
|
||||
|
||||
we check that
|
||||
- `_cls` is a dataclass
|
||||
- `_cls` is not a processed pydantic dataclass (with a basemodel attached)
|
||||
- `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
|
||||
e.g.
|
||||
```
|
||||
@dataclasses.dataclass
|
||||
class A:
|
||||
x: int
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class B(A):
|
||||
y: int
|
||||
```
|
||||
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
|
||||
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
|
||||
"""
|
||||
return (
|
||||
dataclasses.is_dataclass(_cls)
|
||||
and not hasattr(_cls, '__pydantic_model__')
|
||||
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
|
||||
)
|
||||
|
||||
|
||||
def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
|
||||
"""
|
||||
Create a pydantic.dataclass from a builtin dataclass to add type validation
|
||||
and yield the validators
|
||||
It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
|
||||
"""
|
||||
yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True))
|
||||
248
venv/lib/python3.11/site-packages/pydantic/v1/datetime_parse.py
Normal file
248
venv/lib/python3.11/site-packages/pydantic/v1/datetime_parse.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Functions to parse datetime objects.
|
||||
|
||||
We're using regular expressions rather than time.strptime because:
|
||||
- They provide both validation and parsing.
|
||||
- They're more flexible for datetimes.
|
||||
- The date/datetime/time constructors produce friendlier error messages.
|
||||
|
||||
Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at
|
||||
9718fa2e8abe430c3526a9278dd976443d4ae3c6
|
||||
|
||||
Changed to:
|
||||
* use standard python datetime types not django.utils.timezone
|
||||
* raise ValueError when regex doesn't match rather than returning None
|
||||
* support parsing unix timestamps for dates and datetimes
|
||||
"""
|
||||
import re
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from pydantic.v1 import errors
|
||||
|
||||
date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
|
||||
time_expr = (
|
||||
r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
|
||||
r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
|
||||
r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$'
|
||||
)
|
||||
|
||||
date_re = re.compile(f'{date_expr}$')
|
||||
time_re = re.compile(time_expr)
|
||||
datetime_re = re.compile(f'{date_expr}[T ]{time_expr}')
|
||||
|
||||
standard_duration_re = re.compile(
|
||||
r'^'
|
||||
r'(?:(?P<days>-?\d+) (days?, )?)?'
|
||||
r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
|
||||
r'(?:(?P<minutes>-?\d+):)?'
|
||||
r'(?P<seconds>-?\d+)'
|
||||
r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
|
||||
r'$'
|
||||
)
|
||||
|
||||
# Support the sections of ISO 8601 date representation that are accepted by timedelta
|
||||
iso8601_duration_re = re.compile(
|
||||
r'^(?P<sign>[-+]?)'
|
||||
r'P'
|
||||
r'(?:(?P<days>\d+(.\d+)?)D)?'
|
||||
r'(?:T'
|
||||
r'(?:(?P<hours>\d+(.\d+)?)H)?'
|
||||
r'(?:(?P<minutes>\d+(.\d+)?)M)?'
|
||||
r'(?:(?P<seconds>\d+(.\d+)?)S)?'
|
||||
r')?'
|
||||
r'$'
|
||||
)
|
||||
|
||||
EPOCH = datetime(1970, 1, 1)
|
||||
# if greater than this, the number is in ms, if less than or equal it's in seconds
|
||||
# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
|
||||
MS_WATERSHED = int(2e10)
|
||||
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
|
||||
MAX_NUMBER = int(3e20)
|
||||
StrBytesIntFloat = Union[str, bytes, int, float]
|
||||
|
||||
|
||||
def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
except TypeError:
|
||||
raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float')
|
||||
|
||||
|
||||
def from_unix_seconds(seconds: Union[int, float]) -> datetime:
|
||||
if seconds > MAX_NUMBER:
|
||||
return datetime.max
|
||||
elif seconds < -MAX_NUMBER:
|
||||
return datetime.min
|
||||
|
||||
while abs(seconds) > MS_WATERSHED:
|
||||
seconds /= 1000
|
||||
dt = EPOCH + timedelta(seconds=seconds)
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]:
|
||||
if value == 'Z':
|
||||
return timezone.utc
|
||||
elif value is not None:
|
||||
offset_mins = int(value[-2:]) if len(value) > 3 else 0
|
||||
offset = 60 * int(value[1:3]) + offset_mins
|
||||
if value[0] == '-':
|
||||
offset = -offset
|
||||
try:
|
||||
return timezone(timedelta(minutes=offset))
|
||||
except ValueError:
|
||||
raise error()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
|
||||
"""
|
||||
Parse a date/int/float/string and return a datetime.date.
|
||||
|
||||
Raise ValueError if the input is well formatted but not a valid date.
|
||||
Raise ValueError if the input isn't well formatted.
|
||||
"""
|
||||
if isinstance(value, date):
|
||||
if isinstance(value, datetime):
|
||||
return value.date()
|
||||
else:
|
||||
return value
|
||||
|
||||
number = get_numeric(value, 'date')
|
||||
if number is not None:
|
||||
return from_unix_seconds(number).date()
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
match = date_re.match(value) # type: ignore
|
||||
if match is None:
|
||||
raise errors.DateError()
|
||||
|
||||
kw = {k: int(v) for k, v in match.groupdict().items()}
|
||||
|
||||
try:
|
||||
return date(**kw)
|
||||
except ValueError:
|
||||
raise errors.DateError()
|
||||
|
||||
|
||||
def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
|
||||
"""
|
||||
Parse a time/string and return a datetime.time.
|
||||
|
||||
Raise ValueError if the input is well formatted but not a valid time.
|
||||
Raise ValueError if the input isn't well formatted, in particular if it contains an offset.
|
||||
"""
|
||||
if isinstance(value, time):
|
||||
return value
|
||||
|
||||
number = get_numeric(value, 'time')
|
||||
if number is not None:
|
||||
if number >= 86400:
|
||||
# doesn't make sense since the time time loop back around to 0
|
||||
raise errors.TimeError()
|
||||
return (datetime.min + timedelta(seconds=number)).time()
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
match = time_re.match(value) # type: ignore
|
||||
if match is None:
|
||||
raise errors.TimeError()
|
||||
|
||||
kw = match.groupdict()
|
||||
if kw['microsecond']:
|
||||
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
|
||||
|
||||
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError)
|
||||
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
|
||||
kw_['tzinfo'] = tzinfo
|
||||
|
||||
try:
|
||||
return time(**kw_) # type: ignore
|
||||
except ValueError:
|
||||
raise errors.TimeError()
|
||||
|
||||
|
||||
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
|
||||
"""
|
||||
Parse a datetime/int/float/string and return a datetime.datetime.
|
||||
|
||||
This function supports time zone offsets. When the input contains one,
|
||||
the output uses a timezone with a fixed offset from UTC.
|
||||
|
||||
Raise ValueError if the input is well formatted but not a valid datetime.
|
||||
Raise ValueError if the input isn't well formatted.
|
||||
"""
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
|
||||
number = get_numeric(value, 'datetime')
|
||||
if number is not None:
|
||||
return from_unix_seconds(number)
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
match = datetime_re.match(value) # type: ignore
|
||||
if match is None:
|
||||
raise errors.DateTimeError()
|
||||
|
||||
kw = match.groupdict()
|
||||
if kw['microsecond']:
|
||||
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
|
||||
|
||||
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError)
|
||||
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
|
||||
kw_['tzinfo'] = tzinfo
|
||||
|
||||
try:
|
||||
return datetime(**kw_) # type: ignore
|
||||
except ValueError:
|
||||
raise errors.DateTimeError()
|
||||
|
||||
|
||||
def parse_duration(value: StrBytesIntFloat) -> timedelta:
|
||||
"""
|
||||
Parse a duration int/float/string and return a datetime.timedelta.
|
||||
|
||||
The preferred format for durations in Django is '%d %H:%M:%S.%f'.
|
||||
|
||||
Also supports ISO 8601 representation.
|
||||
"""
|
||||
if isinstance(value, timedelta):
|
||||
return value
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
# below code requires a string
|
||||
value = f'{value:f}'
|
||||
elif isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
try:
|
||||
match = standard_duration_re.match(value) or iso8601_duration_re.match(value)
|
||||
except TypeError:
|
||||
raise TypeError('invalid type; expected timedelta, string, bytes, int or float')
|
||||
|
||||
if not match:
|
||||
raise errors.DurationError()
|
||||
|
||||
kw = match.groupdict()
|
||||
sign = -1 if kw.pop('sign', '+') == '-' else 1
|
||||
if kw.get('microseconds'):
|
||||
kw['microseconds'] = kw['microseconds'].ljust(6, '0')
|
||||
|
||||
if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'):
|
||||
kw['microseconds'] = '-' + kw['microseconds']
|
||||
|
||||
kw_ = {k: float(v) for k, v in kw.items() if v is not None}
|
||||
|
||||
return sign * timedelta(**kw_)
|
||||
264
venv/lib/python3.11/site-packages/pydantic/v1/decorator.py
Normal file
264
venv/lib/python3.11/site-packages/pydantic/v1/decorator.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from pydantic.v1 import validator
|
||||
from pydantic.v1.config import Extra
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.main import BaseModel, create_model
|
||||
from pydantic.v1.typing import get_all_type_hints
|
||||
from pydantic.v1.utils import to_camel
|
||||
|
||||
__all__ = ('validate_arguments',)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import AnyCallable
|
||||
|
||||
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
|
||||
ConfigType = Union[None, Type[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
|
||||
...
|
||||
|
||||
|
||||
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
|
||||
"""
|
||||
Decorator to validate the arguments passed to a function.
|
||||
"""
|
||||
|
||||
def validate(_func: 'AnyCallable') -> 'AnyCallable':
|
||||
vd = ValidatedFunction(_func, config)
|
||||
|
||||
@wraps(_func)
|
||||
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
|
||||
return vd.call(*args, **kwargs)
|
||||
|
||||
wrapper_function.vd = vd # type: ignore
|
||||
wrapper_function.validate = vd.init_model_instance # type: ignore
|
||||
wrapper_function.raw_function = vd.raw_function # type: ignore
|
||||
wrapper_function.model = vd.model # type: ignore
|
||||
return wrapper_function
|
||||
|
||||
if func:
|
||||
return validate(func)
|
||||
else:
|
||||
return validate
|
||||
|
||||
|
||||
ALT_V_ARGS = 'v__args'
|
||||
ALT_V_KWARGS = 'v__kwargs'
|
||||
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
|
||||
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
|
||||
|
||||
|
||||
class ValidatedFunction:
|
||||
def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901
|
||||
from inspect import Parameter, signature
|
||||
|
||||
parameters: Mapping[str, Parameter] = signature(function).parameters
|
||||
|
||||
if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
|
||||
raise ConfigError(
|
||||
f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
|
||||
f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator'
|
||||
)
|
||||
|
||||
self.raw_function = function
|
||||
self.arg_mapping: Dict[int, str] = {}
|
||||
self.positional_only_args = set()
|
||||
self.v_args_name = 'args'
|
||||
self.v_kwargs_name = 'kwargs'
|
||||
|
||||
type_hints = get_all_type_hints(function)
|
||||
takes_args = False
|
||||
takes_kwargs = False
|
||||
fields: Dict[str, Tuple[Any, Any]] = {}
|
||||
for i, (name, p) in enumerate(parameters.items()):
|
||||
if p.annotation is p.empty:
|
||||
annotation = Any
|
||||
else:
|
||||
annotation = type_hints[name]
|
||||
|
||||
default = ... if p.default is p.empty else p.default
|
||||
if p.kind == Parameter.POSITIONAL_ONLY:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
|
||||
self.positional_only_args.add(name)
|
||||
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_DUPLICATE_KWARGS] = List[str], None
|
||||
elif p.kind == Parameter.KEYWORD_ONLY:
|
||||
fields[name] = annotation, default
|
||||
elif p.kind == Parameter.VAR_POSITIONAL:
|
||||
self.v_args_name = name
|
||||
fields[name] = Tuple[annotation, ...], None
|
||||
takes_args = True
|
||||
else:
|
||||
assert p.kind == Parameter.VAR_KEYWORD, p.kind
|
||||
self.v_kwargs_name = name
|
||||
fields[name] = Dict[str, annotation], None # type: ignore
|
||||
takes_kwargs = True
|
||||
|
||||
# these checks avoid a clash between "args" and a field with that name
|
||||
if not takes_args and self.v_args_name in fields:
|
||||
self.v_args_name = ALT_V_ARGS
|
||||
|
||||
# same with "kwargs"
|
||||
if not takes_kwargs and self.v_kwargs_name in fields:
|
||||
self.v_kwargs_name = ALT_V_KWARGS
|
||||
|
||||
if not takes_args:
|
||||
# we add the field so validation below can raise the correct exception
|
||||
fields[self.v_args_name] = List[Any], None
|
||||
|
||||
if not takes_kwargs:
|
||||
# same with kwargs
|
||||
fields[self.v_kwargs_name] = Dict[Any, Any], None
|
||||
|
||||
self.create_model(fields, takes_args, takes_kwargs, config)
|
||||
|
||||
def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
|
||||
values = self.build_values(args, kwargs)
|
||||
return self.model(**values)
|
||||
|
||||
def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
m = self.init_model_instance(*args, **kwargs)
|
||||
return self.execute(m)
|
||||
|
||||
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values: Dict[str, Any] = {}
|
||||
if args:
|
||||
arg_iter = enumerate(args)
|
||||
while True:
|
||||
try:
|
||||
i, a = next(arg_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
arg_name = self.arg_mapping.get(i)
|
||||
if arg_name is not None:
|
||||
values[arg_name] = a
|
||||
else:
|
||||
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
|
||||
break
|
||||
|
||||
var_kwargs: Dict[str, Any] = {}
|
||||
wrong_positional_args = []
|
||||
duplicate_kwargs = []
|
||||
fields_alias = [
|
||||
field.alias
|
||||
for name, field in self.model.__fields__.items()
|
||||
if name not in (self.v_args_name, self.v_kwargs_name)
|
||||
]
|
||||
non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name}
|
||||
for k, v in kwargs.items():
|
||||
if k in non_var_fields or k in fields_alias:
|
||||
if k in self.positional_only_args:
|
||||
wrong_positional_args.append(k)
|
||||
if k in values:
|
||||
duplicate_kwargs.append(k)
|
||||
values[k] = v
|
||||
else:
|
||||
var_kwargs[k] = v
|
||||
|
||||
if var_kwargs:
|
||||
values[self.v_kwargs_name] = var_kwargs
|
||||
if wrong_positional_args:
|
||||
values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
|
||||
if duplicate_kwargs:
|
||||
values[V_DUPLICATE_KWARGS] = duplicate_kwargs
|
||||
return values
|
||||
|
||||
def execute(self, m: BaseModel) -> Any:
|
||||
d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory}
|
||||
var_kwargs = d.pop(self.v_kwargs_name, {})
|
||||
|
||||
if self.v_args_name in d:
|
||||
args_: List[Any] = []
|
||||
in_kwargs = False
|
||||
kwargs = {}
|
||||
for name, value in d.items():
|
||||
if in_kwargs:
|
||||
kwargs[name] = value
|
||||
elif name == self.v_args_name:
|
||||
args_ += value
|
||||
in_kwargs = True
|
||||
else:
|
||||
args_.append(value)
|
||||
return self.raw_function(*args_, **kwargs, **var_kwargs)
|
||||
elif self.positional_only_args:
|
||||
args_ = []
|
||||
kwargs = {}
|
||||
for name, value in d.items():
|
||||
if name in self.positional_only_args:
|
||||
args_.append(value)
|
||||
else:
|
||||
kwargs[name] = value
|
||||
return self.raw_function(*args_, **kwargs, **var_kwargs)
|
||||
else:
|
||||
return self.raw_function(**d, **var_kwargs)
|
||||
|
||||
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
|
||||
pos_args = len(self.arg_mapping)
|
||||
|
||||
class CustomConfig:
|
||||
pass
|
||||
|
||||
if not TYPE_CHECKING: # pragma: no branch
|
||||
if isinstance(config, dict):
|
||||
CustomConfig = type('Config', (), config) # noqa: F811
|
||||
elif config is not None:
|
||||
CustomConfig = config # noqa: F811
|
||||
|
||||
if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'):
|
||||
raise ConfigError(
|
||||
'Setting the "fields" and "alias_generator" property on custom Config for '
|
||||
'@validate_arguments is not yet supported, please remove.'
|
||||
)
|
||||
|
||||
class DecoratorBaseModel(BaseModel):
|
||||
@validator(self.v_args_name, check_fields=False, allow_reuse=True)
|
||||
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
|
||||
if takes_args or v is None:
|
||||
return v
|
||||
|
||||
raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')
|
||||
|
||||
@validator(self.v_kwargs_name, check_fields=False, allow_reuse=True)
|
||||
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if takes_kwargs or v is None:
|
||||
return v
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v.keys()))
|
||||
raise TypeError(f'unexpected keyword argument{plural}: {keys}')
|
||||
|
||||
@validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True)
|
||||
def check_positional_only(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v))
|
||||
raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
|
||||
|
||||
@validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True)
|
||||
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v))
|
||||
raise TypeError(f'multiple values for argument{plural}: {keys}')
|
||||
|
||||
class Config(CustomConfig):
|
||||
extra = getattr(CustomConfig, 'extra', Extra.forbid)
|
||||
|
||||
self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
|
||||
350
venv/lib/python3.11/site-packages/pydantic/v1/env_settings.py
Normal file
350
venv/lib/python3.11/site-packages/pydantic/v1/env_settings.py
Normal file
@@ -0,0 +1,350 @@
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic.v1.config import BaseConfig, Extra
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.types import JsonWrapper
|
||||
from pydantic.v1.typing import StrPath, display_as_type, get_origin, is_union
|
||||
from pydantic.v1.utils import deep_update, lenient_issubclass, path_type, sequence_like
|
||||
|
||||
env_file_sentinel = str(object())
|
||||
|
||||
SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]]
|
||||
DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]]
|
||||
|
||||
|
||||
class SettingsError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class BaseSettings(BaseModel):
|
||||
"""
|
||||
Base class for settings, allowing values to be overridden by environment variables.
|
||||
|
||||
This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose),
|
||||
Heroku and any 12 factor app design.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
__pydantic_self__,
|
||||
_env_file: Optional[DotenvType] = env_file_sentinel,
|
||||
_env_file_encoding: Optional[str] = None,
|
||||
_env_nested_delimiter: Optional[str] = None,
|
||||
_secrets_dir: Optional[StrPath] = None,
|
||||
**values: Any,
|
||||
) -> None:
|
||||
# Uses something other than `self` the first arg to allow "self" as a settable attribute
|
||||
super().__init__(
|
||||
**__pydantic_self__._build_values(
|
||||
values,
|
||||
_env_file=_env_file,
|
||||
_env_file_encoding=_env_file_encoding,
|
||||
_env_nested_delimiter=_env_nested_delimiter,
|
||||
_secrets_dir=_secrets_dir,
|
||||
)
|
||||
)
|
||||
|
||||
def _build_values(
|
||||
self,
|
||||
init_kwargs: Dict[str, Any],
|
||||
_env_file: Optional[DotenvType] = None,
|
||||
_env_file_encoding: Optional[str] = None,
|
||||
_env_nested_delimiter: Optional[str] = None,
|
||||
_secrets_dir: Optional[StrPath] = None,
|
||||
) -> Dict[str, Any]:
|
||||
# Configure built-in sources
|
||||
init_settings = InitSettingsSource(init_kwargs=init_kwargs)
|
||||
env_settings = EnvSettingsSource(
|
||||
env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file),
|
||||
env_file_encoding=(
|
||||
_env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding
|
||||
),
|
||||
env_nested_delimiter=(
|
||||
_env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter
|
||||
),
|
||||
env_prefix_len=len(self.__config__.env_prefix),
|
||||
)
|
||||
file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir)
|
||||
# Provide a hook to set built-in sources priority and add / remove sources
|
||||
sources = self.__config__.customise_sources(
|
||||
init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings
|
||||
)
|
||||
if sources:
|
||||
return deep_update(*reversed([source(self) for source in sources]))
|
||||
else:
|
||||
# no one should mean to do this, but I think returning an empty dict is marginally preferable
|
||||
# to an informative error and much better than a confusing error
|
||||
return {}
|
||||
|
||||
class Config(BaseConfig):
|
||||
env_prefix: str = ''
|
||||
env_file: Optional[DotenvType] = None
|
||||
env_file_encoding: Optional[str] = None
|
||||
env_nested_delimiter: Optional[str] = None
|
||||
secrets_dir: Optional[StrPath] = None
|
||||
validate_all: bool = True
|
||||
extra: Extra = Extra.forbid
|
||||
arbitrary_types_allowed: bool = True
|
||||
case_sensitive: bool = False
|
||||
|
||||
@classmethod
|
||||
def prepare_field(cls, field: ModelField) -> None:
|
||||
env_names: Union[List[str], AbstractSet[str]]
|
||||
field_info_from_config = cls.get_field_info(field.name)
|
||||
|
||||
env = field_info_from_config.get('env') or field.field_info.extra.get('env')
|
||||
if env is None:
|
||||
if field.has_alias:
|
||||
warnings.warn(
|
||||
'aliases are no longer used by BaseSettings to define which environment variables to read. '
|
||||
'Instead use the "env" field setting. '
|
||||
'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names',
|
||||
FutureWarning,
|
||||
)
|
||||
env_names = {cls.env_prefix + field.name}
|
||||
elif isinstance(env, str):
|
||||
env_names = {env}
|
||||
elif isinstance(env, (set, frozenset)):
|
||||
env_names = env
|
||||
elif sequence_like(env):
|
||||
env_names = list(env)
|
||||
else:
|
||||
raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set')
|
||||
|
||||
if not cls.case_sensitive:
|
||||
env_names = env_names.__class__(n.lower() for n in env_names)
|
||||
field.field_info.extra['env_names'] = env_names
|
||||
|
||||
@classmethod
|
||||
def customise_sources(
|
||||
cls,
|
||||
init_settings: SettingsSourceCallable,
|
||||
env_settings: SettingsSourceCallable,
|
||||
file_secret_settings: SettingsSourceCallable,
|
||||
) -> Tuple[SettingsSourceCallable, ...]:
|
||||
return init_settings, env_settings, file_secret_settings
|
||||
|
||||
@classmethod
|
||||
def parse_env_var(cls, field_name: str, raw_val: str) -> Any:
|
||||
return cls.json_loads(raw_val)
|
||||
|
||||
# populated by the metaclass using the Config class defined above, annotated here to help IDEs only
|
||||
__config__: ClassVar[Type[Config]]
|
||||
|
||||
|
||||
class InitSettingsSource:
|
||||
__slots__ = ('init_kwargs',)
|
||||
|
||||
def __init__(self, init_kwargs: Dict[str, Any]):
|
||||
self.init_kwargs = init_kwargs
|
||||
|
||||
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
|
||||
return self.init_kwargs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})'
|
||||
|
||||
|
||||
class EnvSettingsSource:
|
||||
__slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_file: Optional[DotenvType],
|
||||
env_file_encoding: Optional[str],
|
||||
env_nested_delimiter: Optional[str] = None,
|
||||
env_prefix_len: int = 0,
|
||||
):
|
||||
self.env_file: Optional[DotenvType] = env_file
|
||||
self.env_file_encoding: Optional[str] = env_file_encoding
|
||||
self.env_nested_delimiter: Optional[str] = env_nested_delimiter
|
||||
self.env_prefix_len: int = env_prefix_len
|
||||
|
||||
def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901
|
||||
"""
|
||||
Build environment variables suitable for passing to the Model.
|
||||
"""
|
||||
d: Dict[str, Any] = {}
|
||||
|
||||
if settings.__config__.case_sensitive:
|
||||
env_vars: Mapping[str, Optional[str]] = os.environ
|
||||
else:
|
||||
env_vars = {k.lower(): v for k, v in os.environ.items()}
|
||||
|
||||
dotenv_vars = self._read_env_files(settings.__config__.case_sensitive)
|
||||
if dotenv_vars:
|
||||
env_vars = {**dotenv_vars, **env_vars}
|
||||
|
||||
for field in settings.__fields__.values():
|
||||
env_val: Optional[str] = None
|
||||
for env_name in field.field_info.extra['env_names']:
|
||||
env_val = env_vars.get(env_name)
|
||||
if env_val is not None:
|
||||
break
|
||||
|
||||
is_complex, allow_parse_failure = self.field_is_complex(field)
|
||||
if is_complex:
|
||||
if env_val is None:
|
||||
# field is complex but no value found so far, try explode_env_vars
|
||||
env_val_built = self.explode_env_vars(field, env_vars)
|
||||
if env_val_built:
|
||||
d[field.alias] = env_val_built
|
||||
else:
|
||||
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
|
||||
try:
|
||||
env_val = settings.__config__.parse_env_var(field.name, env_val)
|
||||
except ValueError as e:
|
||||
if not allow_parse_failure:
|
||||
raise SettingsError(f'error parsing env var "{env_name}"') from e
|
||||
|
||||
if isinstance(env_val, dict):
|
||||
d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars))
|
||||
else:
|
||||
d[field.alias] = env_val
|
||||
elif env_val is not None:
|
||||
# simplest case, field is not complex, we only need to add the value if it was found
|
||||
d[field.alias] = env_val
|
||||
|
||||
return d
|
||||
|
||||
def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]:
|
||||
env_files = self.env_file
|
||||
if env_files is None:
|
||||
return {}
|
||||
|
||||
if isinstance(env_files, (str, os.PathLike)):
|
||||
env_files = [env_files]
|
||||
|
||||
dotenv_vars = {}
|
||||
for env_file in env_files:
|
||||
env_path = Path(env_file).expanduser()
|
||||
if env_path.is_file():
|
||||
dotenv_vars.update(
|
||||
read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
|
||||
)
|
||||
|
||||
return dotenv_vars
|
||||
|
||||
def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
|
||||
"""
|
||||
Find out if a field is complex, and if so whether JSON errors should be ignored
|
||||
"""
|
||||
if lenient_issubclass(field.annotation, JsonWrapper):
|
||||
return False, False
|
||||
|
||||
if field.is_complex():
|
||||
allow_parse_failure = False
|
||||
elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields):
|
||||
allow_parse_failure = True
|
||||
else:
|
||||
return False, False
|
||||
|
||||
return True, allow_parse_failure
|
||||
|
||||
def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
|
||||
|
||||
This is applied to a single field, hence filtering by env_var prefix.
|
||||
"""
|
||||
prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
|
||||
result: Dict[str, Any] = {}
|
||||
for env_name, env_val in env_vars.items():
|
||||
if not any(env_name.startswith(prefix) for prefix in prefixes):
|
||||
continue
|
||||
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
|
||||
env_name_without_prefix = env_name[self.env_prefix_len :]
|
||||
_, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
|
||||
env_var = result
|
||||
for key in keys:
|
||||
env_var = env_var.setdefault(key, {})
|
||||
env_var[last_key] = env_val
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
)
|
||||
|
||||
|
||||
class SecretsSettingsSource:
|
||||
__slots__ = ('secrets_dir',)
|
||||
|
||||
def __init__(self, secrets_dir: Optional[StrPath]):
|
||||
self.secrets_dir: Optional[StrPath] = secrets_dir
|
||||
|
||||
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
|
||||
"""
|
||||
Build fields from "secrets" files.
|
||||
"""
|
||||
secrets: Dict[str, Optional[str]] = {}
|
||||
|
||||
if self.secrets_dir is None:
|
||||
return secrets
|
||||
|
||||
secrets_path = Path(self.secrets_dir).expanduser()
|
||||
|
||||
if not secrets_path.exists():
|
||||
warnings.warn(f'directory "{secrets_path}" does not exist')
|
||||
return secrets
|
||||
|
||||
if not secrets_path.is_dir():
|
||||
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}')
|
||||
|
||||
for field in settings.__fields__.values():
|
||||
for env_name in field.field_info.extra['env_names']:
|
||||
path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive)
|
||||
if not path:
|
||||
# path does not exist, we currently don't return a warning for this
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
secret_value = path.read_text().strip()
|
||||
if field.is_complex():
|
||||
try:
|
||||
secret_value = settings.__config__.parse_env_var(field.name, secret_value)
|
||||
except ValueError as e:
|
||||
raise SettingsError(f'error parsing env var "{env_name}"') from e
|
||||
|
||||
secrets[field.alias] = secret_value
|
||||
else:
|
||||
warnings.warn(
|
||||
f'attempted to load secret file "{path}" but found a {path_type(path)} instead.',
|
||||
stacklevel=4,
|
||||
)
|
||||
return secrets
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
|
||||
|
||||
|
||||
def read_env_file(
|
||||
file_path: StrPath, *, encoding: str = None, case_sensitive: bool = False
|
||||
) -> Dict[str, Optional[str]]:
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
except ImportError as e:
|
||||
raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e
|
||||
|
||||
file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8')
|
||||
if not case_sensitive:
|
||||
return {k.lower(): v for k, v in file_vars.items()}
|
||||
else:
|
||||
return file_vars
|
||||
|
||||
|
||||
def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]:
|
||||
"""
|
||||
Find a file within path's directory matching filename, optionally ignoring case.
|
||||
"""
|
||||
for f in dir_path.iterdir():
|
||||
if f.name == file_name:
|
||||
return f
|
||||
elif not case_sensitive and f.name.lower() == file_name.lower():
|
||||
return f
|
||||
return None
|
||||
161
venv/lib/python3.11/site-packages/pydantic/v1/error_wrappers.py
Normal file
161
venv/lib/python3.11/site-packages/pydantic/v1/error_wrappers.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from pydantic.v1.json import pydantic_encoder
|
||||
from pydantic.v1.utils import Representation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.types import ModelOrDc
|
||||
from pydantic.v1.typing import ReprArgs
|
||||
|
||||
Loc = Tuple[Union[int, str], ...]
|
||||
|
||||
class _ErrorDictRequired(TypedDict):
|
||||
loc: Loc
|
||||
msg: str
|
||||
type: str
|
||||
|
||||
class ErrorDict(_ErrorDictRequired, total=False):
|
||||
ctx: Dict[str, Any]
|
||||
|
||||
|
||||
__all__ = 'ErrorWrapper', 'ValidationError'
|
||||
|
||||
|
||||
class ErrorWrapper(Representation):
|
||||
__slots__ = 'exc', '_loc'
|
||||
|
||||
def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None:
|
||||
self.exc = exc
|
||||
self._loc = loc
|
||||
|
||||
def loc_tuple(self) -> 'Loc':
|
||||
if isinstance(self._loc, tuple):
|
||||
return self._loc
|
||||
else:
|
||||
return (self._loc,)
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [('exc', self.exc), ('loc', self.loc_tuple())]
|
||||
|
||||
|
||||
# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper]
|
||||
# but recursive, therefore just use:
|
||||
ErrorList = Union[Sequence[Any], ErrorWrapper]
|
||||
|
||||
|
||||
class ValidationError(Representation, ValueError):
|
||||
__slots__ = 'raw_errors', 'model', '_error_cache'
|
||||
|
||||
def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None:
|
||||
self.raw_errors = errors
|
||||
self.model = model
|
||||
self._error_cache: Optional[List['ErrorDict']] = None
|
||||
|
||||
def errors(self) -> List['ErrorDict']:
|
||||
if self._error_cache is None:
|
||||
try:
|
||||
config = self.model.__config__ # type: ignore
|
||||
except AttributeError:
|
||||
config = self.model.__pydantic_model__.__config__ # type: ignore
|
||||
self._error_cache = list(flatten_errors(self.raw_errors, config))
|
||||
return self._error_cache
|
||||
|
||||
def json(self, *, indent: Union[None, int, str] = 2) -> str:
|
||||
return json.dumps(self.errors(), indent=indent, default=pydantic_encoder)
|
||||
|
||||
def __str__(self) -> str:
|
||||
errors = self.errors()
|
||||
no_errors = len(errors)
|
||||
return (
|
||||
f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n'
|
||||
f'{display_errors(errors)}'
|
||||
)
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [('model', self.model.__name__), ('errors', self.errors())]
|
||||
|
||||
|
||||
def display_errors(errors: List['ErrorDict']) -> str:
|
||||
return '\n'.join(f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors)
|
||||
|
||||
|
||||
def _display_error_loc(error: 'ErrorDict') -> str:
|
||||
return ' -> '.join(str(e) for e in error['loc'])
|
||||
|
||||
|
||||
def _display_error_type_and_ctx(error: 'ErrorDict') -> str:
|
||||
t = 'type=' + error['type']
|
||||
ctx = error.get('ctx')
|
||||
if ctx:
|
||||
return t + ''.join(f'; {k}={v}' for k, v in ctx.items())
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def flatten_errors(
|
||||
errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None
|
||||
) -> Generator['ErrorDict', None, None]:
|
||||
for error in errors:
|
||||
if isinstance(error, ErrorWrapper):
|
||||
if loc:
|
||||
error_loc = loc + error.loc_tuple()
|
||||
else:
|
||||
error_loc = error.loc_tuple()
|
||||
|
||||
if isinstance(error.exc, ValidationError):
|
||||
yield from flatten_errors(error.exc.raw_errors, config, error_loc)
|
||||
else:
|
||||
yield error_dict(error.exc, config, error_loc)
|
||||
elif isinstance(error, list):
|
||||
yield from flatten_errors(error, config, loc=loc)
|
||||
else:
|
||||
raise RuntimeError(f'Unknown error object: {error}')
|
||||
|
||||
|
||||
def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict':
|
||||
type_ = get_exc_type(exc.__class__)
|
||||
msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None)
|
||||
ctx = exc.__dict__
|
||||
if msg_template:
|
||||
msg = msg_template.format(**ctx)
|
||||
else:
|
||||
msg = str(exc)
|
||||
|
||||
d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_}
|
||||
|
||||
if ctx:
|
||||
d['ctx'] = ctx
|
||||
|
||||
return d
|
||||
|
||||
|
||||
_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {}
|
||||
|
||||
|
||||
def get_exc_type(cls: Type[Exception]) -> str:
|
||||
# slightly more efficient than using lru_cache since we don't need to worry about the cache filling up
|
||||
try:
|
||||
return _EXC_TYPE_CACHE[cls]
|
||||
except KeyError:
|
||||
r = _get_exc_type(cls)
|
||||
_EXC_TYPE_CACHE[cls] = r
|
||||
return r
|
||||
|
||||
|
||||
def _get_exc_type(cls: Type[Exception]) -> str:
|
||||
if issubclass(cls, AssertionError):
|
||||
return 'assertion_error'
|
||||
|
||||
base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error'
|
||||
if cls in (TypeError, ValueError):
|
||||
# just TypeError or ValueError, no extra code
|
||||
return base_name
|
||||
|
||||
# if it's not a TypeError or ValueError, we just take the lowercase of the exception name
|
||||
# no chaining or snake case logic, use "code" for more complex error types.
|
||||
code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower()
|
||||
return base_name + '.' + code
|
||||
646
venv/lib/python3.11/site-packages/pydantic/v1/errors.py
Normal file
646
venv/lib/python3.11/site-packages/pydantic/v1/errors.py
Normal file
@@ -0,0 +1,646 @@
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
from pydantic.v1.typing import display_as_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import DictStrAny
|
||||
|
||||
# explicitly state exports to avoid "from pydantic.v1.errors import *" also importing Decimal, Path etc.
|
||||
__all__ = (
|
||||
'PydanticTypeError',
|
||||
'PydanticValueError',
|
||||
'ConfigError',
|
||||
'MissingError',
|
||||
'ExtraError',
|
||||
'NoneIsNotAllowedError',
|
||||
'NoneIsAllowedError',
|
||||
'WrongConstantError',
|
||||
'NotNoneError',
|
||||
'BoolError',
|
||||
'BytesError',
|
||||
'DictError',
|
||||
'EmailError',
|
||||
'UrlError',
|
||||
'UrlSchemeError',
|
||||
'UrlSchemePermittedError',
|
||||
'UrlUserInfoError',
|
||||
'UrlHostError',
|
||||
'UrlHostTldError',
|
||||
'UrlPortError',
|
||||
'UrlExtraError',
|
||||
'EnumError',
|
||||
'IntEnumError',
|
||||
'EnumMemberError',
|
||||
'IntegerError',
|
||||
'FloatError',
|
||||
'PathError',
|
||||
'PathNotExistsError',
|
||||
'PathNotAFileError',
|
||||
'PathNotADirectoryError',
|
||||
'PyObjectError',
|
||||
'SequenceError',
|
||||
'ListError',
|
||||
'SetError',
|
||||
'FrozenSetError',
|
||||
'TupleError',
|
||||
'TupleLengthError',
|
||||
'ListMinLengthError',
|
||||
'ListMaxLengthError',
|
||||
'ListUniqueItemsError',
|
||||
'SetMinLengthError',
|
||||
'SetMaxLengthError',
|
||||
'FrozenSetMinLengthError',
|
||||
'FrozenSetMaxLengthError',
|
||||
'AnyStrMinLengthError',
|
||||
'AnyStrMaxLengthError',
|
||||
'StrError',
|
||||
'StrRegexError',
|
||||
'NumberNotGtError',
|
||||
'NumberNotGeError',
|
||||
'NumberNotLtError',
|
||||
'NumberNotLeError',
|
||||
'NumberNotMultipleError',
|
||||
'DecimalError',
|
||||
'DecimalIsNotFiniteError',
|
||||
'DecimalMaxDigitsError',
|
||||
'DecimalMaxPlacesError',
|
||||
'DecimalWholeDigitsError',
|
||||
'DateTimeError',
|
||||
'DateError',
|
||||
'DateNotInThePastError',
|
||||
'DateNotInTheFutureError',
|
||||
'TimeError',
|
||||
'DurationError',
|
||||
'HashableError',
|
||||
'UUIDError',
|
||||
'UUIDVersionError',
|
||||
'ArbitraryTypeError',
|
||||
'ClassError',
|
||||
'SubclassError',
|
||||
'JsonError',
|
||||
'JsonTypeError',
|
||||
'PatternError',
|
||||
'DataclassTypeError',
|
||||
'CallableError',
|
||||
'IPvAnyAddressError',
|
||||
'IPvAnyInterfaceError',
|
||||
'IPvAnyNetworkError',
|
||||
'IPv4AddressError',
|
||||
'IPv6AddressError',
|
||||
'IPv4NetworkError',
|
||||
'IPv6NetworkError',
|
||||
'IPv4InterfaceError',
|
||||
'IPv6InterfaceError',
|
||||
'ColorError',
|
||||
'StrictBoolError',
|
||||
'NotDigitError',
|
||||
'LuhnValidationError',
|
||||
'InvalidLengthForBrand',
|
||||
'InvalidByteSize',
|
||||
'InvalidByteSizeUnit',
|
||||
'MissingDiscriminator',
|
||||
'InvalidDiscriminator',
|
||||
)
|
||||
|
||||
|
||||
def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin':
|
||||
"""
|
||||
For built-in exceptions like ValueError or TypeError, we need to implement
|
||||
__reduce__ to override the default behaviour (instead of __getstate__/__setstate__)
|
||||
By default pickle protocol 2 calls `cls.__new__(cls, *args)`.
|
||||
Since we only use kwargs, we need a little constructor to change that.
|
||||
Note: the callable can't be a lambda as pickle looks in the namespace to find it
|
||||
"""
|
||||
return cls(**ctx)
|
||||
|
||||
|
||||
class PydanticErrorMixin:
|
||||
code: str
|
||||
msg_template: str
|
||||
|
||||
def __init__(self, **ctx: Any) -> None:
|
||||
self.__dict__ = ctx
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.msg_template.format(**self.__dict__)
|
||||
|
||||
def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]:
|
||||
return cls_kwargs, (self.__class__, self.__dict__)
|
||||
|
||||
|
||||
class PydanticTypeError(PydanticErrorMixin, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class PydanticValueError(PydanticErrorMixin, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ConfigError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingError(PydanticValueError):
|
||||
msg_template = 'field required'
|
||||
|
||||
|
||||
class ExtraError(PydanticValueError):
|
||||
msg_template = 'extra fields not permitted'
|
||||
|
||||
|
||||
class NoneIsNotAllowedError(PydanticTypeError):
|
||||
code = 'none.not_allowed'
|
||||
msg_template = 'none is not an allowed value'
|
||||
|
||||
|
||||
class NoneIsAllowedError(PydanticTypeError):
|
||||
code = 'none.allowed'
|
||||
msg_template = 'value is not none'
|
||||
|
||||
|
||||
class WrongConstantError(PydanticValueError):
|
||||
code = 'const'
|
||||
|
||||
def __str__(self) -> str:
|
||||
permitted = ', '.join(repr(v) for v in self.permitted) # type: ignore
|
||||
return f'unexpected value; permitted: {permitted}'
|
||||
|
||||
|
||||
class NotNoneError(PydanticTypeError):
|
||||
code = 'not_none'
|
||||
msg_template = 'value is not None'
|
||||
|
||||
|
||||
class BoolError(PydanticTypeError):
|
||||
msg_template = 'value could not be parsed to a boolean'
|
||||
|
||||
|
||||
class BytesError(PydanticTypeError):
|
||||
msg_template = 'byte type expected'
|
||||
|
||||
|
||||
class DictError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid dict'
|
||||
|
||||
|
||||
class EmailError(PydanticValueError):
|
||||
msg_template = 'value is not a valid email address'
|
||||
|
||||
|
||||
class UrlError(PydanticValueError):
|
||||
code = 'url'
|
||||
|
||||
|
||||
class UrlSchemeError(UrlError):
|
||||
code = 'url.scheme'
|
||||
msg_template = 'invalid or missing URL scheme'
|
||||
|
||||
|
||||
class UrlSchemePermittedError(UrlError):
|
||||
code = 'url.scheme'
|
||||
msg_template = 'URL scheme not permitted'
|
||||
|
||||
def __init__(self, allowed_schemes: Set[str]):
|
||||
super().__init__(allowed_schemes=allowed_schemes)
|
||||
|
||||
|
||||
class UrlUserInfoError(UrlError):
|
||||
code = 'url.userinfo'
|
||||
msg_template = 'userinfo required in URL but missing'
|
||||
|
||||
|
||||
class UrlHostError(UrlError):
|
||||
code = 'url.host'
|
||||
msg_template = 'URL host invalid'
|
||||
|
||||
|
||||
class UrlHostTldError(UrlError):
|
||||
code = 'url.host'
|
||||
msg_template = 'URL host invalid, top level domain required'
|
||||
|
||||
|
||||
class UrlPortError(UrlError):
|
||||
code = 'url.port'
|
||||
msg_template = 'URL port invalid, port cannot exceed 65535'
|
||||
|
||||
|
||||
class UrlExtraError(UrlError):
|
||||
code = 'url.extra'
|
||||
msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}'
|
||||
|
||||
|
||||
class EnumMemberError(PydanticTypeError):
|
||||
code = 'enum'
|
||||
|
||||
def __str__(self) -> str:
|
||||
permitted = ', '.join(repr(v.value) for v in self.enum_values) # type: ignore
|
||||
return f'value is not a valid enumeration member; permitted: {permitted}'
|
||||
|
||||
|
||||
class IntegerError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid integer'
|
||||
|
||||
|
||||
class FloatError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid float'
|
||||
|
||||
|
||||
class PathError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid path'
|
||||
|
||||
|
||||
class _PathValueError(PydanticValueError):
|
||||
def __init__(self, *, path: Path) -> None:
|
||||
super().__init__(path=str(path))
|
||||
|
||||
|
||||
class PathNotExistsError(_PathValueError):
|
||||
code = 'path.not_exists'
|
||||
msg_template = 'file or directory at path "{path}" does not exist'
|
||||
|
||||
|
||||
class PathNotAFileError(_PathValueError):
|
||||
code = 'path.not_a_file'
|
||||
msg_template = 'path "{path}" does not point to a file'
|
||||
|
||||
|
||||
class PathNotADirectoryError(_PathValueError):
|
||||
code = 'path.not_a_directory'
|
||||
msg_template = 'path "{path}" does not point to a directory'
|
||||
|
||||
|
||||
class PyObjectError(PydanticTypeError):
|
||||
msg_template = 'ensure this value contains valid import path or valid callable: {error_message}'
|
||||
|
||||
|
||||
class SequenceError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid sequence'
|
||||
|
||||
|
||||
class IterableError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid iterable'
|
||||
|
||||
|
||||
class ListError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid list'
|
||||
|
||||
|
||||
class SetError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid set'
|
||||
|
||||
|
||||
class FrozenSetError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid frozenset'
|
||||
|
||||
|
||||
class DequeError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid deque'
|
||||
|
||||
|
||||
class TupleError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid tuple'
|
||||
|
||||
|
||||
class TupleLengthError(PydanticValueError):
|
||||
code = 'tuple.length'
|
||||
msg_template = 'wrong tuple length {actual_length}, expected {expected_length}'
|
||||
|
||||
def __init__(self, *, actual_length: int, expected_length: int) -> None:
|
||||
super().__init__(actual_length=actual_length, expected_length=expected_length)
|
||||
|
||||
|
||||
class ListMinLengthError(PydanticValueError):
|
||||
code = 'list.min_items'
|
||||
msg_template = 'ensure this value has at least {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class ListMaxLengthError(PydanticValueError):
|
||||
code = 'list.max_items'
|
||||
msg_template = 'ensure this value has at most {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class ListUniqueItemsError(PydanticValueError):
|
||||
code = 'list.unique_items'
|
||||
msg_template = 'the list has duplicated items'
|
||||
|
||||
|
||||
class SetMinLengthError(PydanticValueError):
|
||||
code = 'set.min_items'
|
||||
msg_template = 'ensure this value has at least {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class SetMaxLengthError(PydanticValueError):
|
||||
code = 'set.max_items'
|
||||
msg_template = 'ensure this value has at most {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class FrozenSetMinLengthError(PydanticValueError):
|
||||
code = 'frozenset.min_items'
|
||||
msg_template = 'ensure this value has at least {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class FrozenSetMaxLengthError(PydanticValueError):
|
||||
code = 'frozenset.max_items'
|
||||
msg_template = 'ensure this value has at most {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class AnyStrMinLengthError(PydanticValueError):
|
||||
code = 'any_str.min_length'
|
||||
msg_template = 'ensure this value has at least {limit_value} characters'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class AnyStrMaxLengthError(PydanticValueError):
|
||||
code = 'any_str.max_length'
|
||||
msg_template = 'ensure this value has at most {limit_value} characters'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class StrError(PydanticTypeError):
|
||||
msg_template = 'str type expected'
|
||||
|
||||
|
||||
class StrRegexError(PydanticValueError):
|
||||
code = 'str.regex'
|
||||
msg_template = 'string does not match regex "{pattern}"'
|
||||
|
||||
def __init__(self, *, pattern: str) -> None:
|
||||
super().__init__(pattern=pattern)
|
||||
|
||||
|
||||
class _NumberBoundError(PydanticValueError):
|
||||
def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class NumberNotGtError(_NumberBoundError):
|
||||
code = 'number.not_gt'
|
||||
msg_template = 'ensure this value is greater than {limit_value}'
|
||||
|
||||
|
||||
class NumberNotGeError(_NumberBoundError):
|
||||
code = 'number.not_ge'
|
||||
msg_template = 'ensure this value is greater than or equal to {limit_value}'
|
||||
|
||||
|
||||
class NumberNotLtError(_NumberBoundError):
|
||||
code = 'number.not_lt'
|
||||
msg_template = 'ensure this value is less than {limit_value}'
|
||||
|
||||
|
||||
class NumberNotLeError(_NumberBoundError):
|
||||
code = 'number.not_le'
|
||||
msg_template = 'ensure this value is less than or equal to {limit_value}'
|
||||
|
||||
|
||||
class NumberNotFiniteError(PydanticValueError):
|
||||
code = 'number.not_finite_number'
|
||||
msg_template = 'ensure this value is a finite number'
|
||||
|
||||
|
||||
class NumberNotMultipleError(PydanticValueError):
|
||||
code = 'number.not_multiple'
|
||||
msg_template = 'ensure this value is a multiple of {multiple_of}'
|
||||
|
||||
def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None:
|
||||
super().__init__(multiple_of=multiple_of)
|
||||
|
||||
|
||||
class DecimalError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid decimal'
|
||||
|
||||
|
||||
class DecimalIsNotFiniteError(PydanticValueError):
|
||||
code = 'decimal.not_finite'
|
||||
msg_template = 'value is not a valid decimal'
|
||||
|
||||
|
||||
class DecimalMaxDigitsError(PydanticValueError):
|
||||
code = 'decimal.max_digits'
|
||||
msg_template = 'ensure that there are no more than {max_digits} digits in total'
|
||||
|
||||
def __init__(self, *, max_digits: int) -> None:
|
||||
super().__init__(max_digits=max_digits)
|
||||
|
||||
|
||||
class DecimalMaxPlacesError(PydanticValueError):
|
||||
code = 'decimal.max_places'
|
||||
msg_template = 'ensure that there are no more than {decimal_places} decimal places'
|
||||
|
||||
def __init__(self, *, decimal_places: int) -> None:
|
||||
super().__init__(decimal_places=decimal_places)
|
||||
|
||||
|
||||
class DecimalWholeDigitsError(PydanticValueError):
|
||||
code = 'decimal.whole_digits'
|
||||
msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point'
|
||||
|
||||
def __init__(self, *, whole_digits: int) -> None:
|
||||
super().__init__(whole_digits=whole_digits)
|
||||
|
||||
|
||||
class DateTimeError(PydanticValueError):
|
||||
msg_template = 'invalid datetime format'
|
||||
|
||||
|
||||
class DateError(PydanticValueError):
|
||||
msg_template = 'invalid date format'
|
||||
|
||||
|
||||
class DateNotInThePastError(PydanticValueError):
|
||||
code = 'date.not_in_the_past'
|
||||
msg_template = 'date is not in the past'
|
||||
|
||||
|
||||
class DateNotInTheFutureError(PydanticValueError):
|
||||
code = 'date.not_in_the_future'
|
||||
msg_template = 'date is not in the future'
|
||||
|
||||
|
||||
class TimeError(PydanticValueError):
|
||||
msg_template = 'invalid time format'
|
||||
|
||||
|
||||
class DurationError(PydanticValueError):
|
||||
msg_template = 'invalid duration format'
|
||||
|
||||
|
||||
class HashableError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid hashable'
|
||||
|
||||
|
||||
class UUIDError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid uuid'
|
||||
|
||||
|
||||
class UUIDVersionError(PydanticValueError):
|
||||
code = 'uuid.version'
|
||||
msg_template = 'uuid version {required_version} expected'
|
||||
|
||||
def __init__(self, *, required_version: int) -> None:
|
||||
super().__init__(required_version=required_version)
|
||||
|
||||
|
||||
class ArbitraryTypeError(PydanticTypeError):
|
||||
code = 'arbitrary_type'
|
||||
msg_template = 'instance of {expected_arbitrary_type} expected'
|
||||
|
||||
def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None:
|
||||
super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type))
|
||||
|
||||
|
||||
class ClassError(PydanticTypeError):
|
||||
code = 'class'
|
||||
msg_template = 'a class is expected'
|
||||
|
||||
|
||||
class SubclassError(PydanticTypeError):
|
||||
code = 'subclass'
|
||||
msg_template = 'subclass of {expected_class} expected'
|
||||
|
||||
def __init__(self, *, expected_class: Type[Any]) -> None:
|
||||
super().__init__(expected_class=display_as_type(expected_class))
|
||||
|
||||
|
||||
class JsonError(PydanticValueError):
|
||||
msg_template = 'Invalid JSON'
|
||||
|
||||
|
||||
class JsonTypeError(PydanticTypeError):
|
||||
code = 'json'
|
||||
msg_template = 'JSON object must be str, bytes or bytearray'
|
||||
|
||||
|
||||
class PatternError(PydanticValueError):
|
||||
code = 'regex_pattern'
|
||||
msg_template = 'Invalid regular expression'
|
||||
|
||||
|
||||
class DataclassTypeError(PydanticTypeError):
|
||||
code = 'dataclass'
|
||||
msg_template = 'instance of {class_name}, tuple or dict expected'
|
||||
|
||||
|
||||
class CallableError(PydanticTypeError):
|
||||
msg_template = '{value} is not callable'
|
||||
|
||||
|
||||
class EnumError(PydanticTypeError):
|
||||
code = 'enum_instance'
|
||||
msg_template = '{value} is not a valid Enum instance'
|
||||
|
||||
|
||||
class IntEnumError(PydanticTypeError):
|
||||
code = 'int_enum_instance'
|
||||
msg_template = '{value} is not a valid IntEnum instance'
|
||||
|
||||
|
||||
class IPvAnyAddressError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 or IPv6 address'
|
||||
|
||||
|
||||
class IPvAnyInterfaceError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 or IPv6 interface'
|
||||
|
||||
|
||||
class IPvAnyNetworkError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 or IPv6 network'
|
||||
|
||||
|
||||
class IPv4AddressError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 address'
|
||||
|
||||
|
||||
class IPv6AddressError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv6 address'
|
||||
|
||||
|
||||
class IPv4NetworkError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 network'
|
||||
|
||||
|
||||
class IPv6NetworkError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv6 network'
|
||||
|
||||
|
||||
class IPv4InterfaceError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 interface'
|
||||
|
||||
|
||||
class IPv6InterfaceError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv6 interface'
|
||||
|
||||
|
||||
class ColorError(PydanticValueError):
|
||||
msg_template = 'value is not a valid color: {reason}'
|
||||
|
||||
|
||||
class StrictBoolError(PydanticValueError):
|
||||
msg_template = 'value is not a valid boolean'
|
||||
|
||||
|
||||
class NotDigitError(PydanticValueError):
|
||||
code = 'payment_card_number.digits'
|
||||
msg_template = 'card number is not all digits'
|
||||
|
||||
|
||||
class LuhnValidationError(PydanticValueError):
|
||||
code = 'payment_card_number.luhn_check'
|
||||
msg_template = 'card number is not luhn valid'
|
||||
|
||||
|
||||
class InvalidLengthForBrand(PydanticValueError):
|
||||
code = 'payment_card_number.invalid_length_for_brand'
|
||||
msg_template = 'Length for a {brand} card must be {required_length}'
|
||||
|
||||
|
||||
class InvalidByteSize(PydanticValueError):
|
||||
msg_template = 'could not parse value and unit from byte string'
|
||||
|
||||
|
||||
class InvalidByteSizeUnit(PydanticValueError):
|
||||
msg_template = 'could not interpret byte unit: {unit}'
|
||||
|
||||
|
||||
class MissingDiscriminator(PydanticValueError):
|
||||
code = 'discriminated_union.missing_discriminator'
|
||||
msg_template = 'Discriminator {discriminator_key!r} is missing in value'
|
||||
|
||||
|
||||
class InvalidDiscriminator(PydanticValueError):
|
||||
code = 'discriminated_union.invalid_discriminator'
|
||||
msg_template = (
|
||||
'No match for discriminator {discriminator_key!r} and value {discriminator_value!r} '
|
||||
'(allowed values: {allowed_values})'
|
||||
)
|
||||
|
||||
def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None:
|
||||
super().__init__(
|
||||
discriminator_key=discriminator_key,
|
||||
discriminator_value=discriminator_value,
|
||||
allowed_values=', '.join(map(repr, allowed_values)),
|
||||
)
|
||||
1253
venv/lib/python3.11/site-packages/pydantic/v1/fields.py
Normal file
1253
venv/lib/python3.11/site-packages/pydantic/v1/fields.py
Normal file
File diff suppressed because it is too large
Load Diff
400
venv/lib/python3.11/site-packages/pydantic/v1/generics.py
Normal file
400
venv/lib/python3.11/site-packages/pydantic/v1/generics.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from weakref import WeakKeyDictionary, WeakValueDictionary
|
||||
|
||||
from typing_extensions import Annotated, Literal as ExtLiteral
|
||||
|
||||
from pydantic.v1.class_validators import gather_all_validators
|
||||
from pydantic.v1.fields import DeferredType
|
||||
from pydantic.v1.main import BaseModel, create_model
|
||||
from pydantic.v1.types import JsonWrapper
|
||||
from pydantic.v1.typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
|
||||
from pydantic.v1.utils import all_identical, lenient_issubclass
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import _UnionGenericAlias
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
|
||||
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
|
||||
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
|
||||
|
||||
CacheKey = Tuple[Type[Any], Any, Tuple[Any, ...]]
|
||||
Parametrization = Mapping[TypeVarType, Type[Any]]
|
||||
|
||||
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
|
||||
# once they are no longer referenced by the caller.
|
||||
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
|
||||
GenericTypesCache = WeakValueDictionary[CacheKey, Type[BaseModel]]
|
||||
AssignedParameters = WeakKeyDictionary[Type[BaseModel], Parametrization]
|
||||
else:
|
||||
GenericTypesCache = WeakValueDictionary
|
||||
AssignedParameters = WeakKeyDictionary
|
||||
|
||||
# _generic_types_cache is a Mapping from __class_getitem__ arguments to the parametrized version of generic models.
|
||||
# This ensures multiple calls of e.g. A[B] return always the same class.
|
||||
_generic_types_cache = GenericTypesCache()
|
||||
|
||||
# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
|
||||
# as captured during construction of the class (not instances).
|
||||
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
|
||||
# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`.
|
||||
# (This information is only otherwise available after creation from the class name string).
|
||||
_assigned_parameters = AssignedParameters()
|
||||
|
||||
|
||||
class GenericModel(BaseModel):
|
||||
__slots__ = ()
|
||||
__concrete__: ClassVar[bool] = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
|
||||
# `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
|
||||
# `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
|
||||
__parameters__: ClassVar[Tuple[TypeVarType, ...]]
|
||||
|
||||
# Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
|
||||
def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
|
||||
"""Instantiates a new class from a generic class `cls` and type variables `params`.
|
||||
|
||||
:param params: Tuple of types the class . Given a generic class
|
||||
`Model` with 2 type variables and a concrete model `Model[str, int]`,
|
||||
the value `(str, int)` would be passed to `params`.
|
||||
:return: New model class inheriting from `cls` with instantiated
|
||||
types described by `params`. If no parameters are given, `cls` is
|
||||
returned as is.
|
||||
|
||||
"""
|
||||
|
||||
def _cache_key(_params: Any) -> CacheKey:
|
||||
args = get_args(_params)
|
||||
# python returns a list for Callables, which is not hashable
|
||||
if len(args) == 2 and isinstance(args[0], list):
|
||||
args = (tuple(args[0]), args[1])
|
||||
return cls, _params, args
|
||||
|
||||
cached = _generic_types_cache.get(_cache_key(params))
|
||||
if cached is not None:
|
||||
return cached
|
||||
if cls.__concrete__ and Generic not in cls.__bases__:
|
||||
raise TypeError('Cannot parameterize a concrete instantiation of a generic model')
|
||||
if not isinstance(params, tuple):
|
||||
params = (params,)
|
||||
if cls is GenericModel and any(isinstance(param, TypeVar) for param in params):
|
||||
raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel')
|
||||
if not hasattr(cls, '__parameters__'):
|
||||
raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')
|
||||
|
||||
check_parameters_count(cls, params)
|
||||
# Build map from generic typevars to passed params
|
||||
typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
|
||||
if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
|
||||
return cls # if arguments are equal to parameters it's the same object
|
||||
|
||||
# Create new model with original model as parent inserting fields with DeferredType.
|
||||
model_name = cls.__concrete_name__(params)
|
||||
validators = gather_all_validators(cls)
|
||||
|
||||
type_hints = get_all_type_hints(cls).items()
|
||||
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
|
||||
|
||||
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
|
||||
|
||||
model_module, called_globally = get_caller_frame_info()
|
||||
created_model = cast(
|
||||
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
|
||||
create_model(
|
||||
model_name,
|
||||
__module__=model_module or cls.__module__,
|
||||
__base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
|
||||
__config__=None,
|
||||
__validators__=validators,
|
||||
__cls_kwargs__=None,
|
||||
**fields,
|
||||
),
|
||||
)
|
||||
|
||||
_assigned_parameters[created_model] = typevars_map
|
||||
|
||||
if called_globally: # create global reference and therefore allow pickling
|
||||
object_by_reference = None
|
||||
reference_name = model_name
|
||||
reference_module_globals = sys.modules[created_model.__module__].__dict__
|
||||
while object_by_reference is not created_model:
|
||||
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
|
||||
reference_name += '_'
|
||||
|
||||
created_model.Config = cls.Config
|
||||
|
||||
# Find any typevars that are still present in the model.
|
||||
# If none are left, the model is fully "concrete", otherwise the new
|
||||
# class is a generic class as well taking the found typevars as
|
||||
# parameters.
|
||||
new_params = tuple(
|
||||
{param: None for param in iter_contained_typevars(typevars_map.values())}
|
||||
) # use dict as ordered set
|
||||
created_model.__concrete__ = not new_params
|
||||
if new_params:
|
||||
created_model.__parameters__ = new_params
|
||||
|
||||
# Save created model in cache so we don't end up creating duplicate
|
||||
# models that should be identical.
|
||||
_generic_types_cache[_cache_key(params)] = created_model
|
||||
if len(params) == 1:
|
||||
_generic_types_cache[_cache_key(params[0])] = created_model
|
||||
|
||||
# Recursively walk class type hints and replace generic typevars
|
||||
# with concrete types that were passed.
|
||||
_prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)
|
||||
|
||||
return created_model
|
||||
|
||||
@classmethod
|
||||
def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
|
||||
"""Compute class name for child classes.
|
||||
|
||||
:param params: Tuple of types the class . Given a generic class
|
||||
`Model` with 2 type variables and a concrete model `Model[str, int]`,
|
||||
the value `(str, int)` would be passed to `params`.
|
||||
:return: String representing a the new class where `params` are
|
||||
passed to `cls` as type variables.
|
||||
|
||||
This method can be overridden to achieve a custom naming scheme for GenericModels.
|
||||
"""
|
||||
param_names = [display_as_type(param) for param in params]
|
||||
params_component = ', '.join(param_names)
|
||||
return f'{cls.__name__}[{params_component}]'
|
||||
|
||||
@classmethod
|
||||
def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
|
||||
"""
|
||||
Returns unbound bases of cls parameterised to given type variables
|
||||
|
||||
:param typevars_map: Dictionary of type applications for binding subclasses.
|
||||
Given a generic class `Model` with 2 type variables [S, T]
|
||||
and a concrete model `Model[str, int]`,
|
||||
the value `{S: str, T: int}` would be passed to `typevars_map`.
|
||||
:return: an iterator of generic sub classes, parameterised by `typevars_map`
|
||||
and other assigned parameters of `cls`
|
||||
|
||||
e.g.:
|
||||
```
|
||||
class A(GenericModel, Generic[T]):
|
||||
...
|
||||
|
||||
class B(A[V], Generic[V]):
|
||||
...
|
||||
|
||||
assert A[int] in B.__parameterized_bases__({V: int})
|
||||
```
|
||||
"""
|
||||
|
||||
def build_base_model(
|
||||
base_model: Type[GenericModel], mapped_types: Parametrization
|
||||
) -> Iterator[Type[GenericModel]]:
|
||||
base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__)
|
||||
parameterized_base = base_model.__class_getitem__(base_parameters)
|
||||
if parameterized_base is base_model or parameterized_base is cls:
|
||||
# Avoid duplication in MRO
|
||||
return
|
||||
yield parameterized_base
|
||||
|
||||
for base_model in cls.__bases__:
|
||||
if not issubclass(base_model, GenericModel):
|
||||
# not a class that can be meaningfully parameterized
|
||||
continue
|
||||
elif not getattr(base_model, '__parameters__', None):
|
||||
# base_model is "GenericModel" (and has no __parameters__)
|
||||
# or
|
||||
# base_model is already concrete, and will be included transitively via cls.
|
||||
continue
|
||||
elif cls in _assigned_parameters:
|
||||
if base_model in _assigned_parameters:
|
||||
# cls is partially parameterised but not from base_model
|
||||
# e.g. cls = B[S], base_model = A[S]
|
||||
# B[S][int] should subclass A[int], (and will be transitively via B[int])
|
||||
# but it's not viable to consistently subclass types with arbitrary construction
|
||||
# So don't attempt to include A[S][int]
|
||||
continue
|
||||
else: # base_model not in _assigned_parameters:
|
||||
# cls is partially parameterized, base_model is original generic
|
||||
# e.g. cls = B[str, T], base_model = B[S, T]
|
||||
# Need to determine the mapping for the base_model parameters
|
||||
mapped_types: Parametrization = {
|
||||
key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
|
||||
}
|
||||
yield from build_base_model(base_model, mapped_types)
|
||||
else:
|
||||
# cls is base generic, so base_class has a distinct base
|
||||
# can construct the Parameterised base model using typevars_map directly
|
||||
yield from build_base_model(base_model, typevars_map)
|
||||
|
||||
|
||||
def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
|
||||
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
|
||||
|
||||
:param type_: Any type, class or generic alias
|
||||
:param type_map: Mapping from `TypeVar` instance to concrete types.
|
||||
:return: New type representing the basic structure of `type_` with all
|
||||
`typevar_map` keys recursively replaced.
|
||||
|
||||
>>> replace_types(Tuple[str, Union[List[str], float]], {str: int})
|
||||
Tuple[int, Union[List[int], float]]
|
||||
|
||||
"""
|
||||
if not type_map:
|
||||
return type_
|
||||
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]
|
||||
|
||||
if (origin_type is ExtLiteral) or (sys.version_info >= (3, 8) and origin_type is Literal):
|
||||
return type_map.get(type_, type_)
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if type_args:
|
||||
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
# If all arguments are the same, there is no need to modify the
|
||||
# type or create a new object at all
|
||||
return type_
|
||||
if (
|
||||
origin_type is not None
|
||||
and isinstance(type_, typing_base)
|
||||
and not isinstance(origin_type, typing_base)
|
||||
and getattr(type_, '_name', None) is not None
|
||||
):
|
||||
# In python < 3.9 generic aliases don't exist so any of these like `list`,
|
||||
# `type` or `collections.abc.Callable` need to be translated.
|
||||
# See: https://www.python.org/dev/peps/pep-0585
|
||||
origin_type = getattr(typing, type_._name)
|
||||
assert origin_type is not None
|
||||
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
|
||||
# We also cannot use isinstance() since we have to compare types.
|
||||
if sys.version_info >= (3, 10) and origin_type is types.UnionType: # noqa: E721
|
||||
return _UnionGenericAlias(origin_type, resolved_type_args)
|
||||
return origin_type[resolved_type_args]
|
||||
|
||||
# We handle pydantic generic models separately as they don't have the same
|
||||
# semantics as "typing" classes or generic aliases
|
||||
if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__:
|
||||
type_args = type_.__parameters__
|
||||
resolved_type_args = tuple(replace_types(t, type_map) for t in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
return type_
|
||||
return type_[resolved_type_args]
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)):
|
||||
resolved_list = list(replace_types(element, type_map) for element in type_)
|
||||
if all_identical(type_, resolved_list):
|
||||
return type_
|
||||
return resolved_list
|
||||
|
||||
# For JsonWrapperValue, need to handle its inner type to allow correct parsing
|
||||
# of generic Json arguments like Json[T]
|
||||
if not origin_type and lenient_issubclass(type_, JsonWrapper):
|
||||
type_.inner_type = replace_types(type_.inner_type, type_map)
|
||||
return type_
|
||||
|
||||
# If all else fails, we try to resolve the type directly and otherwise just
|
||||
# return the input with no modifications.
|
||||
new_type = type_map.get(type_, type_)
|
||||
# Convert string to ForwardRef
|
||||
if isinstance(new_type, str):
|
||||
return ForwardRef(new_type)
|
||||
else:
|
||||
return new_type
|
||||
|
||||
|
||||
def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
|
||||
actual = len(parameters)
|
||||
expected = len(cls.__parameters__)
|
||||
if actual != expected:
|
||||
description = 'many' if actual > expected else 'few'
|
||||
raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')
|
||||
|
||||
|
||||
DictValues: Type[Any] = {}.values().__class__
|
||||
|
||||
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
|
||||
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found."""
|
||||
if isinstance(v, TypeVar):
|
||||
yield v
|
||||
elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel):
|
||||
yield from v.__parameters__
|
||||
elif isinstance(v, (DictValues, list)):
|
||||
for var in v:
|
||||
yield from iter_contained_typevars(var)
|
||||
else:
|
||||
args = get_args(v)
|
||||
for arg in args:
|
||||
yield from iter_contained_typevars(arg)
|
||||
|
||||
|
||||
def get_caller_frame_info() -> Tuple[Optional[str], bool]:
|
||||
"""
|
||||
Used inside a function to check whether it was called globally
|
||||
|
||||
Will only work against non-compiled code, therefore used only in pydantic.generics
|
||||
|
||||
:returns Tuple[module_name, called_globally]
|
||||
"""
|
||||
try:
|
||||
previous_caller_frame = sys._getframe(2)
|
||||
except ValueError as e:
|
||||
raise RuntimeError('This function must be used inside another function') from e
|
||||
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
|
||||
return None, False
|
||||
frame_globals = previous_caller_frame.f_globals
|
||||
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
|
||||
|
||||
|
||||
def _prepare_model_fields(
|
||||
created_model: Type[GenericModel],
|
||||
fields: Mapping[str, Any],
|
||||
instance_type_hints: Mapping[str, type],
|
||||
typevars_map: Mapping[Any, type],
|
||||
) -> None:
|
||||
"""
|
||||
Replace DeferredType fields with concrete type hints and prepare them.
|
||||
"""
|
||||
|
||||
for key, field in created_model.__fields__.items():
|
||||
if key not in fields:
|
||||
assert field.type_.__class__ is not DeferredType
|
||||
# https://github.com/nedbat/coveragepy/issues/198
|
||||
continue # pragma: no cover
|
||||
|
||||
assert field.type_.__class__ is DeferredType, field.type_.__class__
|
||||
|
||||
field_type_hint = instance_type_hints[key]
|
||||
concrete_type = replace_types(field_type_hint, typevars_map)
|
||||
field.type_ = concrete_type
|
||||
field.outer_type_ = concrete_type
|
||||
field.prepare()
|
||||
created_model.__annotations__[key] = concrete_type
|
||||
112
venv/lib/python3.11/site-packages/pydantic/v1/json.py
Normal file
112
venv/lib/python3.11/site-packages/pydantic/v1/json.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import datetime
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from types import GeneratorType
|
||||
from typing import Any, Callable, Dict, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.v1.color import Color
|
||||
from pydantic.v1.networks import NameEmail
|
||||
from pydantic.v1.types import SecretBytes, SecretStr
|
||||
|
||||
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
|
||||
|
||||
|
||||
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
|
||||
return o.isoformat()
|
||||
|
||||
|
||||
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
|
||||
"""
|
||||
Encodes a Decimal as int of there's no exponent, otherwise float
|
||||
|
||||
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
|
||||
where a integer (but not int typed) is used. Encoding this as a float
|
||||
results in failed round-tripping between encode and parse.
|
||||
Our Id type is a prime example of this.
|
||||
|
||||
>>> decimal_encoder(Decimal("1.0"))
|
||||
1.0
|
||||
|
||||
>>> decimal_encoder(Decimal("1"))
|
||||
1
|
||||
"""
|
||||
if dec_value.as_tuple().exponent >= 0:
|
||||
return int(dec_value)
|
||||
else:
|
||||
return float(dec_value)
|
||||
|
||||
|
||||
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
bytes: lambda o: o.decode(),
|
||||
Color: str,
|
||||
datetime.date: isoformat,
|
||||
datetime.datetime: isoformat,
|
||||
datetime.time: isoformat,
|
||||
datetime.timedelta: lambda td: td.total_seconds(),
|
||||
Decimal: decimal_encoder,
|
||||
Enum: lambda o: o.value,
|
||||
frozenset: list,
|
||||
deque: list,
|
||||
GeneratorType: list,
|
||||
IPv4Address: str,
|
||||
IPv4Interface: str,
|
||||
IPv4Network: str,
|
||||
IPv6Address: str,
|
||||
IPv6Interface: str,
|
||||
IPv6Network: str,
|
||||
NameEmail: str,
|
||||
Path: str,
|
||||
Pattern: lambda o: o.pattern,
|
||||
SecretBytes: str,
|
||||
SecretStr: str,
|
||||
set: list,
|
||||
UUID: str,
|
||||
}
|
||||
|
||||
|
||||
def pydantic_encoder(obj: Any) -> Any:
|
||||
from dataclasses import asdict, is_dataclass
|
||||
|
||||
from pydantic.v1.main import BaseModel
|
||||
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.dict()
|
||||
elif is_dataclass(obj):
|
||||
return asdict(obj)
|
||||
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
try:
|
||||
encoder = ENCODERS_BY_TYPE[base]
|
||||
except KeyError:
|
||||
continue
|
||||
return encoder(obj)
|
||||
else: # We have exited the for loop without finding a suitable encoder
|
||||
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
|
||||
|
||||
|
||||
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
try:
|
||||
encoder = type_encoders[base]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
return encoder(obj)
|
||||
else: # We have exited the for loop without finding a suitable encoder
|
||||
return pydantic_encoder(obj)
|
||||
|
||||
|
||||
def timedelta_isoformat(td: datetime.timedelta) -> str:
|
||||
"""
|
||||
ISO 8601 encoding for Python timedelta object.
|
||||
"""
|
||||
minutes, seconds = divmod(td.seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'
|
||||
1107
venv/lib/python3.11/site-packages/pydantic/v1/main.py
Normal file
1107
venv/lib/python3.11/site-packages/pydantic/v1/main.py
Normal file
File diff suppressed because it is too large
Load Diff
949
venv/lib/python3.11/site-packages/pydantic/v1/mypy.py
Normal file
949
venv/lib/python3.11/site-packages/pydantic/v1/mypy.py
Normal file
@@ -0,0 +1,949 @@
|
||||
import sys
|
||||
from configparser import ConfigParser
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union
|
||||
|
||||
from mypy.errorcodes import ErrorCode
|
||||
from mypy.nodes import (
|
||||
ARG_NAMED,
|
||||
ARG_NAMED_OPT,
|
||||
ARG_OPT,
|
||||
ARG_POS,
|
||||
ARG_STAR2,
|
||||
MDEF,
|
||||
Argument,
|
||||
AssignmentStmt,
|
||||
Block,
|
||||
CallExpr,
|
||||
ClassDef,
|
||||
Context,
|
||||
Decorator,
|
||||
EllipsisExpr,
|
||||
FuncBase,
|
||||
FuncDef,
|
||||
JsonDict,
|
||||
MemberExpr,
|
||||
NameExpr,
|
||||
PassStmt,
|
||||
PlaceholderNode,
|
||||
RefExpr,
|
||||
StrExpr,
|
||||
SymbolNode,
|
||||
SymbolTableNode,
|
||||
TempNode,
|
||||
TypeInfo,
|
||||
TypeVarExpr,
|
||||
Var,
|
||||
)
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import (
|
||||
CheckerPluginInterface,
|
||||
ClassDefContext,
|
||||
FunctionContext,
|
||||
MethodContext,
|
||||
Plugin,
|
||||
ReportConfigContext,
|
||||
SemanticAnalyzerPluginInterface,
|
||||
)
|
||||
from mypy.plugins import dataclasses
|
||||
from mypy.semanal import set_callable_name # type: ignore
|
||||
from mypy.server.trigger import make_wildcard_trigger
|
||||
from mypy.types import (
|
||||
AnyType,
|
||||
CallableType,
|
||||
Instance,
|
||||
NoneType,
|
||||
Overloaded,
|
||||
ProperType,
|
||||
Type,
|
||||
TypeOfAny,
|
||||
TypeType,
|
||||
TypeVarId,
|
||||
TypeVarType,
|
||||
UnionType,
|
||||
get_proper_type,
|
||||
)
|
||||
from mypy.typevars import fill_typevars
|
||||
from mypy.util import get_unique_redefinition_name
|
||||
from mypy.version import __version__ as mypy_version
|
||||
|
||||
from pydantic.v1.utils import is_valid_field
|
||||
|
||||
try:
|
||||
from mypy.types import TypeVarDef # type: ignore[attr-defined]
|
||||
except ImportError: # pragma: no cover
|
||||
# Backward-compatible with TypeVarDef from Mypy 0.910.
|
||||
from mypy.types import TypeVarType as TypeVarDef
|
||||
|
||||
CONFIGFILE_KEY = 'pydantic-mypy'
|
||||
METADATA_KEY = 'pydantic-mypy-metadata'
|
||||
_NAMESPACE = __name__[:-5] # 'pydantic' in 1.10.X, 'pydantic.v1' in v2.X
|
||||
BASEMODEL_FULLNAME = f'{_NAMESPACE}.main.BaseModel'
|
||||
BASESETTINGS_FULLNAME = f'{_NAMESPACE}.env_settings.BaseSettings'
|
||||
MODEL_METACLASS_FULLNAME = f'{_NAMESPACE}.main.ModelMetaclass'
|
||||
FIELD_FULLNAME = f'{_NAMESPACE}.fields.Field'
|
||||
DATACLASS_FULLNAME = f'{_NAMESPACE}.dataclasses.dataclass'
|
||||
|
||||
|
||||
def parse_mypy_version(version: str) -> Tuple[int, ...]:
|
||||
return tuple(map(int, version.partition('+')[0].split('.')))
|
||||
|
||||
|
||||
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
|
||||
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'
|
||||
|
||||
# Increment version if plugin changes and mypy caches should be invalidated
|
||||
__version__ = 2
|
||||
|
||||
|
||||
def plugin(version: str) -> 'TypingType[Plugin]':
|
||||
"""
|
||||
`version` is the mypy version string
|
||||
|
||||
We might want to use this to print a warning if the mypy version being used is
|
||||
newer, or especially older, than we expect (or need).
|
||||
"""
|
||||
return PydanticPlugin
|
||||
|
||||
|
||||
class PydanticPlugin(Plugin):
|
||||
def __init__(self, options: Options) -> None:
|
||||
self.plugin_config = PydanticPluginConfig(options)
|
||||
self._plugin_data = self.plugin_config.to_data()
|
||||
super().__init__(options)
|
||||
|
||||
def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]':
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and isinstance(sym.node, TypeInfo): # pragma: no branch
|
||||
# No branching may occur if the mypy cache has not been cleared
|
||||
if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro):
|
||||
return self._pydantic_model_class_maker_callback
|
||||
return None
|
||||
|
||||
def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if fullname == MODEL_METACLASS_FULLNAME:
|
||||
return self._pydantic_model_metaclass_marker_callback
|
||||
return None
|
||||
|
||||
def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]':
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and sym.fullname == FIELD_FULLNAME:
|
||||
return self._pydantic_field_callback
|
||||
return None
|
||||
|
||||
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]:
|
||||
if fullname.endswith('.from_orm'):
|
||||
return from_orm_callback
|
||||
return None
|
||||
|
||||
def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
"""Mark pydantic.dataclasses as dataclass.
|
||||
|
||||
Mypy version 1.1.1 added support for `@dataclass_transform` decorator.
|
||||
"""
|
||||
if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1):
|
||||
return dataclasses.dataclass_class_maker_callback # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def report_config_data(self, ctx: ReportConfigContext) -> Dict[str, Any]:
|
||||
"""Return all plugin config data.
|
||||
|
||||
Used by mypy to determine if cache needs to be discarded.
|
||||
"""
|
||||
return self._plugin_data
|
||||
|
||||
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
|
||||
transformer = PydanticModelTransformer(ctx, self.plugin_config)
|
||||
transformer.transform()
|
||||
|
||||
def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
|
||||
"""Reset dataclass_transform_spec attribute of ModelMetaclass.
|
||||
|
||||
Let the plugin handle it. This behavior can be disabled
|
||||
if 'debug_dataclass_transform' is set to True', for testing purposes.
|
||||
"""
|
||||
if self.plugin_config.debug_dataclass_transform:
|
||||
return
|
||||
info_metaclass = ctx.cls.info.declared_metaclass
|
||||
assert info_metaclass, "callback not passed from 'get_metaclass_hook'"
|
||||
if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
|
||||
info_metaclass.type.dataclass_transform_spec = None # type: ignore[attr-defined]
|
||||
|
||||
def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
|
||||
"""
|
||||
Extract the type of the `default` argument from the Field function, and use it as the return type.
|
||||
|
||||
In particular:
|
||||
* Check whether the default and default_factory argument is specified.
|
||||
* Output an error if both are specified.
|
||||
* Retrieve the type of the argument which is specified, and use it as return type for the function.
|
||||
"""
|
||||
default_any_type = ctx.default_return_type
|
||||
|
||||
assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
|
||||
assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
|
||||
default_args = ctx.args[0]
|
||||
default_factory_args = ctx.args[1]
|
||||
|
||||
if default_args and default_factory_args:
|
||||
error_default_and_default_factory_specified(ctx.api, ctx.context)
|
||||
return default_any_type
|
||||
|
||||
if default_args:
|
||||
default_type = ctx.arg_types[0][0]
|
||||
default_arg = default_args[0]
|
||||
|
||||
# Fallback to default Any type if the field is required
|
||||
if not isinstance(default_arg, EllipsisExpr):
|
||||
return default_type
|
||||
|
||||
elif default_factory_args:
|
||||
default_factory_type = ctx.arg_types[1][0]
|
||||
|
||||
# Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter
|
||||
# Pydantic calls the default factory without any argument, so we retrieve the first item
|
||||
if isinstance(default_factory_type, Overloaded):
|
||||
if MYPY_VERSION_TUPLE > (0, 910):
|
||||
default_factory_type = default_factory_type.items[0]
|
||||
else:
|
||||
# Mypy0.910 exposes the items of overloaded types in a function
|
||||
default_factory_type = default_factory_type.items()[0] # type: ignore[operator]
|
||||
|
||||
if isinstance(default_factory_type, CallableType):
|
||||
ret_type = default_factory_type.ret_type
|
||||
# mypy doesn't think `ret_type` has `args`, you'd think mypy should know,
|
||||
# add this check in case it varies by version
|
||||
args = getattr(ret_type, 'args', None)
|
||||
if args:
|
||||
if all(isinstance(arg, TypeVarType) for arg in args):
|
||||
# Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
|
||||
ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined]
|
||||
return ret_type
|
||||
|
||||
return default_any_type
|
||||
|
||||
|
||||
class PydanticPluginConfig:
|
||||
__slots__ = (
|
||||
'init_forbid_extra',
|
||||
'init_typed',
|
||||
'warn_required_dynamic_aliases',
|
||||
'warn_untyped_fields',
|
||||
'debug_dataclass_transform',
|
||||
)
|
||||
init_forbid_extra: bool
|
||||
init_typed: bool
|
||||
warn_required_dynamic_aliases: bool
|
||||
warn_untyped_fields: bool
|
||||
debug_dataclass_transform: bool # undocumented
|
||||
|
||||
def __init__(self, options: Options) -> None:
|
||||
if options.config_file is None: # pragma: no cover
|
||||
return
|
||||
|
||||
toml_config = parse_toml(options.config_file)
|
||||
if toml_config is not None:
|
||||
config = toml_config.get('tool', {}).get('pydantic-mypy', {})
|
||||
for key in self.__slots__:
|
||||
setting = config.get(key, False)
|
||||
if not isinstance(setting, bool):
|
||||
raise ValueError(f'Configuration value must be a boolean for key: {key}')
|
||||
setattr(self, key, setting)
|
||||
else:
|
||||
plugin_config = ConfigParser()
|
||||
plugin_config.read(options.config_file)
|
||||
for key in self.__slots__:
|
||||
setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False)
|
||||
setattr(self, key, setting)
|
||||
|
||||
def to_data(self) -> Dict[str, Any]:
|
||||
return {key: getattr(self, key) for key in self.__slots__}
|
||||
|
||||
|
||||
def from_orm_callback(ctx: MethodContext) -> Type:
|
||||
"""
|
||||
Raise an error if orm_mode is not enabled
|
||||
"""
|
||||
model_type: Instance
|
||||
ctx_type = ctx.type
|
||||
if isinstance(ctx_type, TypeType):
|
||||
ctx_type = ctx_type.item
|
||||
if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance):
|
||||
model_type = ctx_type.ret_type # called on the class
|
||||
elif isinstance(ctx_type, Instance):
|
||||
model_type = ctx_type # called on an instance (unusual, but still valid)
|
||||
else: # pragma: no cover
|
||||
detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})'
|
||||
error_unexpected_behavior(detail, ctx.api, ctx.context)
|
||||
return ctx.default_return_type
|
||||
pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
|
||||
if pydantic_metadata is None:
|
||||
return ctx.default_return_type
|
||||
orm_mode = pydantic_metadata.get('config', {}).get('orm_mode')
|
||||
if orm_mode is not True:
|
||||
error_from_orm(get_name(model_type.type), ctx.api, ctx.context)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
class PydanticModelTransformer:
|
||||
tracked_config_fields: Set[str] = {
|
||||
'extra',
|
||||
'allow_mutation',
|
||||
'frozen',
|
||||
'orm_mode',
|
||||
'allow_population_by_field_name',
|
||||
'alias_generator',
|
||||
}
|
||||
|
||||
def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None:
|
||||
self._ctx = ctx
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
def transform(self) -> None:
|
||||
"""
|
||||
Configures the BaseModel subclass according to the plugin settings.
|
||||
|
||||
In particular:
|
||||
* determines the model config and fields,
|
||||
* adds a fields-aware signature for the initializer and construct methods
|
||||
* freezes the class if allow_mutation = False or frozen = True
|
||||
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
|
||||
"""
|
||||
ctx = self._ctx
|
||||
info = ctx.cls.info
|
||||
|
||||
self.adjust_validator_signatures()
|
||||
config = self.collect_config()
|
||||
fields = self.collect_fields(config)
|
||||
is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1])
|
||||
self.add_initializer(fields, config, is_settings)
|
||||
self.add_construct_method(fields)
|
||||
self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True)
|
||||
info.metadata[METADATA_KEY] = {
|
||||
'fields': {field.name: field.serialize() for field in fields},
|
||||
'config': config.set_values_dict(),
|
||||
}
|
||||
|
||||
def adjust_validator_signatures(self) -> None:
|
||||
"""When we decorate a function `f` with `pydantic.validator(...), mypy sees
|
||||
`f` as a regular method taking a `self` instance, even though pydantic
|
||||
internally wraps `f` with `classmethod` if necessary.
|
||||
|
||||
Teach mypy this by marking any function whose outermost decorator is a
|
||||
`validator()` call as a classmethod.
|
||||
"""
|
||||
for name, sym in self._ctx.cls.info.names.items():
|
||||
if isinstance(sym.node, Decorator):
|
||||
first_dec = sym.node.original_decorators[0]
|
||||
if (
|
||||
isinstance(first_dec, CallExpr)
|
||||
and isinstance(first_dec.callee, NameExpr)
|
||||
and first_dec.callee.fullname == f'{_NAMESPACE}.class_validators.validator'
|
||||
):
|
||||
sym.node.func.is_class = True
|
||||
|
||||
def collect_config(self) -> 'ModelConfigData':
|
||||
"""
|
||||
Collects the values of the config attributes that are used by the plugin, accounting for parent classes.
|
||||
"""
|
||||
ctx = self._ctx
|
||||
cls = ctx.cls
|
||||
config = ModelConfigData()
|
||||
for stmt in cls.defs.body:
|
||||
if not isinstance(stmt, ClassDef):
|
||||
continue
|
||||
if stmt.name == 'Config':
|
||||
for substmt in stmt.defs.body:
|
||||
if not isinstance(substmt, AssignmentStmt):
|
||||
continue
|
||||
config.update(self.get_config_update(substmt))
|
||||
if (
|
||||
config.has_alias_generator
|
||||
and not config.allow_population_by_field_name
|
||||
and self.plugin_config.warn_required_dynamic_aliases
|
||||
):
|
||||
error_required_dynamic_aliases(ctx.api, stmt)
|
||||
for info in cls.info.mro[1:]: # 0 is the current class
|
||||
if METADATA_KEY not in info.metadata:
|
||||
continue
|
||||
|
||||
# Each class depends on the set of fields in its ancestors
|
||||
ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
|
||||
for name, value in info.metadata[METADATA_KEY]['config'].items():
|
||||
config.setdefault(name, value)
|
||||
return config
|
||||
|
||||
def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']:
|
||||
"""
|
||||
Collects the fields for the model, accounting for parent classes
|
||||
"""
|
||||
# First, collect fields belonging to the current class.
|
||||
ctx = self._ctx
|
||||
cls = self._ctx.cls
|
||||
fields = [] # type: List[PydanticModelField]
|
||||
known_fields = set() # type: Set[str]
|
||||
for stmt in cls.defs.body:
|
||||
if not isinstance(stmt, AssignmentStmt): # `and stmt.new_syntax` to require annotation
|
||||
continue
|
||||
|
||||
lhs = stmt.lvalues[0]
|
||||
if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name):
|
||||
continue
|
||||
|
||||
if not stmt.new_syntax and self.plugin_config.warn_untyped_fields:
|
||||
error_untyped_fields(ctx.api, stmt)
|
||||
|
||||
# if lhs.name == '__config__': # BaseConfig not well handled; I'm not sure why yet
|
||||
# continue
|
||||
|
||||
sym = cls.info.names.get(lhs.name)
|
||||
if sym is None: # pragma: no cover
|
||||
# This is likely due to a star import (see the dataclasses plugin for a more detailed explanation)
|
||||
# This is the same logic used in the dataclasses plugin
|
||||
continue
|
||||
|
||||
node = sym.node
|
||||
if isinstance(node, PlaceholderNode): # pragma: no cover
|
||||
# See the PlaceholderNode docstring for more detail about how this can occur
|
||||
# Basically, it is an edge case when dealing with complex import logic
|
||||
# This is the same logic used in the dataclasses plugin
|
||||
continue
|
||||
if not isinstance(node, Var): # pragma: no cover
|
||||
# Don't know if this edge case still happens with the `is_valid_field` check above
|
||||
# but better safe than sorry
|
||||
continue
|
||||
|
||||
# x: ClassVar[int] is ignored by dataclasses.
|
||||
if node.is_classvar:
|
||||
continue
|
||||
|
||||
is_required = self.get_is_required(cls, stmt, lhs)
|
||||
alias, has_dynamic_alias = self.get_alias_info(stmt)
|
||||
if (
|
||||
has_dynamic_alias
|
||||
and not model_config.allow_population_by_field_name
|
||||
and self.plugin_config.warn_required_dynamic_aliases
|
||||
):
|
||||
error_required_dynamic_aliases(ctx.api, stmt)
|
||||
fields.append(
|
||||
PydanticModelField(
|
||||
name=lhs.name,
|
||||
is_required=is_required,
|
||||
alias=alias,
|
||||
has_dynamic_alias=has_dynamic_alias,
|
||||
line=stmt.line,
|
||||
column=stmt.column,
|
||||
)
|
||||
)
|
||||
known_fields.add(lhs.name)
|
||||
all_fields = fields.copy()
|
||||
for info in cls.info.mro[1:]: # 0 is the current class, -2 is BaseModel, -1 is object
|
||||
if METADATA_KEY not in info.metadata:
|
||||
continue
|
||||
|
||||
superclass_fields = []
|
||||
# Each class depends on the set of fields in its ancestors
|
||||
ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
|
||||
|
||||
for name, data in info.metadata[METADATA_KEY]['fields'].items():
|
||||
if name not in known_fields:
|
||||
field = PydanticModelField.deserialize(info, data)
|
||||
known_fields.add(name)
|
||||
superclass_fields.append(field)
|
||||
else:
|
||||
(field,) = (a for a in all_fields if a.name == name)
|
||||
all_fields.remove(field)
|
||||
superclass_fields.append(field)
|
||||
all_fields = superclass_fields + all_fields
|
||||
return all_fields
|
||||
|
||||
def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None:
|
||||
"""
|
||||
Adds a fields-aware `__init__` method to the class.
|
||||
|
||||
The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
|
||||
"""
|
||||
ctx = self._ctx
|
||||
typed = self.plugin_config.init_typed
|
||||
use_alias = config.allow_population_by_field_name is not True
|
||||
force_all_optional = is_settings or bool(
|
||||
config.has_alias_generator and not config.allow_population_by_field_name
|
||||
)
|
||||
init_arguments = self.get_field_arguments(
|
||||
fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias
|
||||
)
|
||||
if not self.should_init_forbid_extra(fields, config):
|
||||
var = Var('kwargs')
|
||||
init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))
|
||||
|
||||
if '__init__' not in ctx.cls.info.names:
|
||||
add_method(ctx, '__init__', init_arguments, NoneType())
|
||||
|
||||
def add_construct_method(self, fields: List['PydanticModelField']) -> None:
|
||||
"""
|
||||
Adds a fully typed `construct` classmethod to the class.
|
||||
|
||||
Similar to the fields-aware __init__ method, but always uses the field names (not aliases),
|
||||
and does not treat settings fields as optional.
|
||||
"""
|
||||
ctx = self._ctx
|
||||
set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')])
|
||||
optional_set_str = UnionType([set_str, NoneType()])
|
||||
fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
|
||||
construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False)
|
||||
construct_arguments = [fields_set_argument] + construct_arguments
|
||||
|
||||
obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object')
|
||||
self_tvar_name = '_PydanticBaseModel' # Make sure it does not conflict with other names in the class
|
||||
tvar_fullname = ctx.cls.fullname + '.' + self_tvar_name
|
||||
if MYPY_VERSION_TUPLE >= (1, 4):
|
||||
tvd = TypeVarType(
|
||||
self_tvar_name,
|
||||
tvar_fullname,
|
||||
(
|
||||
TypeVarId(-1, namespace=ctx.cls.fullname + '.construct')
|
||||
if MYPY_VERSION_TUPLE >= (1, 11)
|
||||
else TypeVarId(-1)
|
||||
),
|
||||
[],
|
||||
obj_type,
|
||||
AnyType(TypeOfAny.from_omitted_generics), # type: ignore[arg-type]
|
||||
)
|
||||
self_tvar_expr = TypeVarExpr(
|
||||
self_tvar_name,
|
||||
tvar_fullname,
|
||||
[],
|
||||
obj_type,
|
||||
AnyType(TypeOfAny.from_omitted_generics), # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type)
|
||||
self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type)
|
||||
ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr)
|
||||
|
||||
# Backward-compatible with TypeVarDef from Mypy 0.910.
|
||||
if isinstance(tvd, TypeVarType):
|
||||
self_type = tvd
|
||||
else:
|
||||
self_type = TypeVarType(tvd)
|
||||
|
||||
add_method(
|
||||
ctx,
|
||||
'construct',
|
||||
construct_arguments,
|
||||
return_type=self_type,
|
||||
self_type=self_type,
|
||||
tvar_def=tvd,
|
||||
is_classmethod=True,
|
||||
)
|
||||
|
||||
def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None:
|
||||
"""
|
||||
Marks all fields as properties so that attempts to set them trigger mypy errors.
|
||||
|
||||
This is the same approach used by the attrs and dataclasses plugins.
|
||||
"""
|
||||
ctx = self._ctx
|
||||
info = ctx.cls.info
|
||||
for field in fields:
|
||||
sym_node = info.names.get(field.name)
|
||||
if sym_node is not None:
|
||||
var = sym_node.node
|
||||
if isinstance(var, Var):
|
||||
var.is_property = frozen
|
||||
elif isinstance(var, PlaceholderNode) and not ctx.api.final_iteration:
|
||||
# See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage
|
||||
ctx.api.defer()
|
||||
else: # pragma: no cover
|
||||
# I don't know whether it's possible to hit this branch, but I've added it for safety
|
||||
try:
|
||||
var_str = str(var)
|
||||
except TypeError:
|
||||
# This happens for PlaceholderNode; perhaps it will happen for other types in the future..
|
||||
var_str = repr(var)
|
||||
detail = f'sym_node.node: {var_str} (of type {var.__class__})'
|
||||
error_unexpected_behavior(detail, ctx.api, ctx.cls)
|
||||
else:
|
||||
var = field.to_var(info, use_alias=False)
|
||||
var.info = info
|
||||
var.is_property = frozen
|
||||
var._fullname = get_fullname(info) + '.' + get_name(var)
|
||||
info.names[get_name(var)] = SymbolTableNode(MDEF, var)
|
||||
|
||||
def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']:
|
||||
"""
|
||||
Determines the config update due to a single statement in the Config class definition.
|
||||
|
||||
Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int)
|
||||
"""
|
||||
lhs = substmt.lvalues[0]
|
||||
if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields):
|
||||
return None
|
||||
if lhs.name == 'extra':
|
||||
if isinstance(substmt.rvalue, StrExpr):
|
||||
forbid_extra = substmt.rvalue.value == 'forbid'
|
||||
elif isinstance(substmt.rvalue, MemberExpr):
|
||||
forbid_extra = substmt.rvalue.name == 'forbid'
|
||||
else:
|
||||
error_invalid_config_value(lhs.name, self._ctx.api, substmt)
|
||||
return None
|
||||
return ModelConfigData(forbid_extra=forbid_extra)
|
||||
if lhs.name == 'alias_generator':
|
||||
has_alias_generator = True
|
||||
if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None':
|
||||
has_alias_generator = False
|
||||
return ModelConfigData(has_alias_generator=has_alias_generator)
|
||||
if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'):
|
||||
return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'})
|
||||
error_invalid_config_value(lhs.name, self._ctx.api, substmt)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool:
|
||||
"""
|
||||
Returns a boolean indicating whether the field defined in `stmt` is a required field.
|
||||
"""
|
||||
expr = stmt.rvalue
|
||||
if isinstance(expr, TempNode):
|
||||
# TempNode means annotation-only, so only non-required if Optional
|
||||
value_type = get_proper_type(cls.info[lhs.name].type)
|
||||
return not PydanticModelTransformer.type_has_implicit_default(value_type)
|
||||
if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
|
||||
# The "default value" is a call to `Field`; at this point, the field is
|
||||
# only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) or if default_factory
|
||||
# is specified.
|
||||
for arg, name in zip(expr.args, expr.arg_names):
|
||||
# If name is None, then this arg is the default because it is the only positional argument.
|
||||
if name is None or name == 'default':
|
||||
return arg.__class__ is EllipsisExpr
|
||||
if name == 'default_factory':
|
||||
return False
|
||||
# In this case, default and default_factory are not specified, so we need to look at the annotation
|
||||
value_type = get_proper_type(cls.info[lhs.name].type)
|
||||
return not PydanticModelTransformer.type_has_implicit_default(value_type)
|
||||
# Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
|
||||
return isinstance(expr, EllipsisExpr)
|
||||
|
||||
@staticmethod
|
||||
def type_has_implicit_default(type_: Optional[ProperType]) -> bool:
|
||||
"""
|
||||
Returns True if the passed type will be given an implicit default value.
|
||||
|
||||
In pydantic v1, this is the case for Optional types and Any (with default value None).
|
||||
"""
|
||||
if isinstance(type_, AnyType):
|
||||
# Annotated as Any
|
||||
return True
|
||||
if isinstance(type_, UnionType) and any(
|
||||
isinstance(item, NoneType) or isinstance(item, AnyType) for item in type_.items
|
||||
):
|
||||
# Annotated as Optional, or otherwise having NoneType or AnyType in the union
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]:
|
||||
"""
|
||||
Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.
|
||||
|
||||
`has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal.
|
||||
If `has_dynamic_alias` is True, `alias` will be None.
|
||||
"""
|
||||
expr = stmt.rvalue
|
||||
if isinstance(expr, TempNode):
|
||||
# TempNode means annotation-only
|
||||
return None, False
|
||||
|
||||
if not (
|
||||
isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
|
||||
):
|
||||
# Assigned value is not a call to pydantic.fields.Field
|
||||
return None, False
|
||||
|
||||
for i, arg_name in enumerate(expr.arg_names):
|
||||
if arg_name != 'alias':
|
||||
continue
|
||||
arg = expr.args[i]
|
||||
if isinstance(arg, StrExpr):
|
||||
return arg.value, False
|
||||
else:
|
||||
return None, True
|
||||
return None, False
|
||||
|
||||
def get_field_arguments(
|
||||
self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool
|
||||
) -> List[Argument]:
|
||||
"""
|
||||
Helper function used during the construction of the `__init__` and `construct` method signatures.
|
||||
|
||||
Returns a list of mypy Argument instances for use in the generated signatures.
|
||||
"""
|
||||
info = self._ctx.cls.info
|
||||
arguments = [
|
||||
field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias)
|
||||
for field in fields
|
||||
if not (use_alias and field.has_dynamic_alias)
|
||||
]
|
||||
return arguments
|
||||
|
||||
def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool:
|
||||
"""
|
||||
Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature
|
||||
|
||||
We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
|
||||
*unless* a required dynamic alias is present (since then we can't determine a valid signature).
|
||||
"""
|
||||
if not config.allow_population_by_field_name:
|
||||
if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
|
||||
return False
|
||||
if config.forbid_extra:
|
||||
return True
|
||||
return self.plugin_config.init_forbid_extra
|
||||
|
||||
@staticmethod
|
||||
def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool:
|
||||
"""
|
||||
Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be
|
||||
determined during static analysis.
|
||||
"""
|
||||
for field in fields:
|
||||
if field.has_dynamic_alias:
|
||||
return True
|
||||
if has_alias_generator:
|
||||
for field in fields:
|
||||
if field.alias is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class PydanticModelField:
|
||||
def __init__(
|
||||
self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int
|
||||
):
|
||||
self.name = name
|
||||
self.is_required = is_required
|
||||
self.alias = alias
|
||||
self.has_dynamic_alias = has_dynamic_alias
|
||||
self.line = line
|
||||
self.column = column
|
||||
|
||||
def to_var(self, info: TypeInfo, use_alias: bool) -> Var:
|
||||
name = self.name
|
||||
if use_alias and self.alias is not None:
|
||||
name = self.alias
|
||||
return Var(name, info[self.name].type)
|
||||
|
||||
def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument:
|
||||
if typed and info[self.name].type is not None:
|
||||
type_annotation = info[self.name].type
|
||||
else:
|
||||
type_annotation = AnyType(TypeOfAny.explicit)
|
||||
return Argument(
|
||||
variable=self.to_var(info, use_alias),
|
||||
type_annotation=type_annotation,
|
||||
initializer=None,
|
||||
kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED,
|
||||
)
|
||||
|
||||
def serialize(self) -> JsonDict:
|
||||
return self.__dict__
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField':
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class ModelConfigData:
|
||||
def __init__(
|
||||
self,
|
||||
forbid_extra: Optional[bool] = None,
|
||||
allow_mutation: Optional[bool] = None,
|
||||
frozen: Optional[bool] = None,
|
||||
orm_mode: Optional[bool] = None,
|
||||
allow_population_by_field_name: Optional[bool] = None,
|
||||
has_alias_generator: Optional[bool] = None,
|
||||
):
|
||||
self.forbid_extra = forbid_extra
|
||||
self.allow_mutation = allow_mutation
|
||||
self.frozen = frozen
|
||||
self.orm_mode = orm_mode
|
||||
self.allow_population_by_field_name = allow_population_by_field_name
|
||||
self.has_alias_generator = has_alias_generator
|
||||
|
||||
def set_values_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in self.__dict__.items() if v is not None}
|
||||
|
||||
def update(self, config: Optional['ModelConfigData']) -> None:
|
||||
if config is None:
|
||||
return
|
||||
for k, v in config.set_values_dict().items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def setdefault(self, key: str, value: Any) -> None:
|
||||
if getattr(self, key) is None:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic')
|
||||
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
|
||||
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
|
||||
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
|
||||
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
|
||||
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
|
||||
|
||||
|
||||
def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
|
||||
api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM)
|
||||
|
||||
|
||||
def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None:
|
||||
api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG)
|
||||
|
||||
|
||||
def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
|
||||
api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS)
|
||||
|
||||
|
||||
def error_unexpected_behavior(
|
||||
detail: str, api: Union[CheckerPluginInterface, SemanticAnalyzerPluginInterface], context: Context
|
||||
) -> None: # pragma: no cover
|
||||
# Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path
|
||||
link = 'https://github.com/pydantic/pydantic/issues/new/choose'
|
||||
full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n'
|
||||
full_message += f'Please consider reporting this bug at {link} so we can try to fix it!'
|
||||
api.fail(full_message, context, code=ERROR_UNEXPECTED)
|
||||
|
||||
|
||||
def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
|
||||
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)
|
||||
|
||||
|
||||
def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
|
||||
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
|
||||
|
||||
|
||||
def add_method(
|
||||
ctx: ClassDefContext,
|
||||
name: str,
|
||||
args: List[Argument],
|
||||
return_type: Type,
|
||||
self_type: Optional[Type] = None,
|
||||
tvar_def: Optional[TypeVarDef] = None,
|
||||
is_classmethod: bool = False,
|
||||
is_new: bool = False,
|
||||
# is_staticmethod: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new method to a class.
|
||||
|
||||
This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged
|
||||
"""
|
||||
info = ctx.cls.info
|
||||
|
||||
# First remove any previously generated methods with the same name
|
||||
# to avoid clashes and problems in the semantic analyzer.
|
||||
if name in info.names:
|
||||
sym = info.names[name]
|
||||
if sym.plugin_generated and isinstance(sym.node, FuncDef):
|
||||
ctx.cls.defs.body.remove(sym.node) # pragma: no cover
|
||||
|
||||
self_type = self_type or fill_typevars(info)
|
||||
if is_classmethod or is_new:
|
||||
first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)]
|
||||
# elif is_staticmethod:
|
||||
# first = []
|
||||
else:
|
||||
self_type = self_type or fill_typevars(info)
|
||||
first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
|
||||
args = first + args
|
||||
arg_types, arg_names, arg_kinds = [], [], []
|
||||
for arg in args:
|
||||
assert arg.type_annotation, 'All arguments must be fully typed.'
|
||||
arg_types.append(arg.type_annotation)
|
||||
arg_names.append(get_name(arg.variable))
|
||||
arg_kinds.append(arg.kind)
|
||||
|
||||
function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function')
|
||||
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
|
||||
if tvar_def:
|
||||
signature.variables = [tvar_def]
|
||||
|
||||
func = FuncDef(name, args, Block([PassStmt()]))
|
||||
func.info = info
|
||||
func.type = set_callable_name(signature, func)
|
||||
func.is_class = is_classmethod
|
||||
# func.is_static = is_staticmethod
|
||||
func._fullname = get_fullname(info) + '.' + name
|
||||
func.line = info.line
|
||||
|
||||
# NOTE: we would like the plugin generated node to dominate, but we still
|
||||
# need to keep any existing definitions so they get semantically analyzed.
|
||||
if name in info.names:
|
||||
# Get a nice unique name instead.
|
||||
r_name = get_unique_redefinition_name(name, info.names)
|
||||
info.names[r_name] = info.names[name]
|
||||
|
||||
if is_classmethod: # or is_staticmethod:
|
||||
func.is_decorated = True
|
||||
v = Var(name, func.type)
|
||||
v.info = info
|
||||
v._fullname = func._fullname
|
||||
# if is_classmethod:
|
||||
v.is_classmethod = True
|
||||
dec = Decorator(func, [NameExpr('classmethod')], v)
|
||||
# else:
|
||||
# v.is_staticmethod = True
|
||||
# dec = Decorator(func, [NameExpr('staticmethod')], v)
|
||||
|
||||
dec.line = info.line
|
||||
sym = SymbolTableNode(MDEF, dec)
|
||||
else:
|
||||
sym = SymbolTableNode(MDEF, func)
|
||||
sym.plugin_generated = True
|
||||
|
||||
info.names[name] = sym
|
||||
info.defn.defs.body.append(func)
|
||||
|
||||
|
||||
def get_fullname(x: Union[FuncBase, SymbolNode]) -> str:
|
||||
"""
|
||||
Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
|
||||
"""
|
||||
fn = x.fullname
|
||||
if callable(fn): # pragma: no cover
|
||||
return fn()
|
||||
return fn
|
||||
|
||||
|
||||
def get_name(x: Union[FuncBase, SymbolNode]) -> str:
|
||||
"""
|
||||
Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
|
||||
"""
|
||||
fn = x.name
|
||||
if callable(fn): # pragma: no cover
|
||||
return fn()
|
||||
return fn
|
||||
|
||||
|
||||
def parse_toml(config_file: str) -> Optional[Dict[str, Any]]:
|
||||
if not config_file.endswith('.toml'):
|
||||
return None
|
||||
|
||||
read_mode = 'rb'
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib as toml_
|
||||
else:
|
||||
try:
|
||||
import tomli as toml_
|
||||
except ImportError:
|
||||
# older versions of mypy have toml as a dependency, not tomli
|
||||
read_mode = 'r'
|
||||
try:
|
||||
import toml as toml_ # type: ignore[no-redef]
|
||||
except ImportError: # pragma: no cover
|
||||
import warnings
|
||||
|
||||
warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.')
|
||||
return None
|
||||
|
||||
with open(config_file, read_mode) as rf:
|
||||
return toml_.load(rf) # type: ignore[arg-type]
|
||||
747
venv/lib/python3.11/site-packages/pydantic/v1/networks.py
Normal file
747
venv/lib/python3.11/site-packages/pydantic/v1/networks.py
Normal file
@@ -0,0 +1,747 @@
|
||||
import re
|
||||
from ipaddress import (
|
||||
IPv4Address,
|
||||
IPv4Interface,
|
||||
IPv4Network,
|
||||
IPv6Address,
|
||||
IPv6Interface,
|
||||
IPv6Network,
|
||||
_BaseAddress,
|
||||
_BaseNetwork,
|
||||
)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Match,
|
||||
Optional,
|
||||
Pattern,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
no_type_check,
|
||||
)
|
||||
|
||||
from pydantic.v1 import errors
|
||||
from pydantic.v1.utils import Representation, update_not_none
|
||||
from pydantic.v1.validators import constr_length_validator, str_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import email_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.typing import AnyCallable
|
||||
|
||||
CallableGenerator = Generator[AnyCallable, None, None]
|
||||
|
||||
class Parts(TypedDict, total=False):
|
||||
scheme: str
|
||||
user: Optional[str]
|
||||
password: Optional[str]
|
||||
ipv4: Optional[str]
|
||||
ipv6: Optional[str]
|
||||
domain: Optional[str]
|
||||
port: Optional[str]
|
||||
path: Optional[str]
|
||||
query: Optional[str]
|
||||
fragment: Optional[str]
|
||||
|
||||
class HostParts(TypedDict, total=False):
|
||||
host: str
|
||||
tld: Optional[str]
|
||||
host_type: Optional[str]
|
||||
port: Optional[str]
|
||||
rebuild: bool
|
||||
|
||||
else:
|
||||
email_validator = None
|
||||
|
||||
class Parts(dict):
|
||||
pass
|
||||
|
||||
|
||||
NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]]
|
||||
|
||||
__all__ = [
|
||||
'AnyUrl',
|
||||
'AnyHttpUrl',
|
||||
'FileUrl',
|
||||
'HttpUrl',
|
||||
'stricturl',
|
||||
'EmailStr',
|
||||
'NameEmail',
|
||||
'IPvAnyAddress',
|
||||
'IPvAnyInterface',
|
||||
'IPvAnyNetwork',
|
||||
'PostgresDsn',
|
||||
'CockroachDsn',
|
||||
'AmqpDsn',
|
||||
'RedisDsn',
|
||||
'MongoDsn',
|
||||
'KafkaDsn',
|
||||
'validate_email',
|
||||
]
|
||||
|
||||
_url_regex_cache = None
|
||||
_multi_host_url_regex_cache = None
|
||||
_ascii_domain_regex_cache = None
|
||||
_int_domain_regex_cache = None
|
||||
_host_regex_cache = None
|
||||
|
||||
_host_regex = (
|
||||
r'(?:'
|
||||
r'(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|' # ipv4
|
||||
r'(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|' # ipv6
|
||||
r'(?P<domain>[^\s/:?#]+)' # domain, validation occurs later
|
||||
r')?'
|
||||
r'(?::(?P<port>\d+))?' # port
|
||||
)
|
||||
_scheme_regex = r'(?:(?P<scheme>[a-z][a-z0-9+\-.]+)://)?' # scheme https://tools.ietf.org/html/rfc3986#appendix-A
|
||||
_user_info_regex = r'(?:(?P<user>[^\s:/]*)(?::(?P<password>[^\s/]*))?@)?'
|
||||
_path_regex = r'(?P<path>/[^\s?#]*)?'
|
||||
_query_regex = r'(?:\?(?P<query>[^\s#]*))?'
|
||||
_fragment_regex = r'(?:#(?P<fragment>[^\s#]*))?'
|
||||
|
||||
|
||||
def url_regex() -> Pattern[str]:
|
||||
global _url_regex_cache
|
||||
if _url_regex_cache is None:
|
||||
_url_regex_cache = re.compile(
|
||||
rf'{_scheme_regex}{_user_info_regex}{_host_regex}{_path_regex}{_query_regex}{_fragment_regex}',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return _url_regex_cache
|
||||
|
||||
|
||||
def multi_host_url_regex() -> Pattern[str]:
|
||||
"""
|
||||
Compiled multi host url regex.
|
||||
|
||||
Additionally to `url_regex` it allows to match multiple hosts.
|
||||
E.g. host1.db.net,host2.db.net
|
||||
"""
|
||||
global _multi_host_url_regex_cache
|
||||
if _multi_host_url_regex_cache is None:
|
||||
_multi_host_url_regex_cache = re.compile(
|
||||
rf'{_scheme_regex}{_user_info_regex}'
|
||||
r'(?P<hosts>([^/]*))' # validation occurs later
|
||||
rf'{_path_regex}{_query_regex}{_fragment_regex}',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return _multi_host_url_regex_cache
|
||||
|
||||
|
||||
def ascii_domain_regex() -> Pattern[str]:
|
||||
global _ascii_domain_regex_cache
|
||||
if _ascii_domain_regex_cache is None:
|
||||
ascii_chunk = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?'
|
||||
ascii_domain_ending = r'(?P<tld>\.[a-z]{2,63})?\.?'
|
||||
_ascii_domain_regex_cache = re.compile(
|
||||
fr'(?:{ascii_chunk}\.)*?{ascii_chunk}{ascii_domain_ending}', re.IGNORECASE
|
||||
)
|
||||
return _ascii_domain_regex_cache
|
||||
|
||||
|
||||
def int_domain_regex() -> Pattern[str]:
|
||||
global _int_domain_regex_cache
|
||||
if _int_domain_regex_cache is None:
|
||||
int_chunk = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?'
|
||||
int_domain_ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?'
|
||||
_int_domain_regex_cache = re.compile(fr'(?:{int_chunk}\.)*?{int_chunk}{int_domain_ending}', re.IGNORECASE)
|
||||
return _int_domain_regex_cache
|
||||
|
||||
|
||||
def host_regex() -> Pattern[str]:
|
||||
global _host_regex_cache
|
||||
if _host_regex_cache is None:
|
||||
_host_regex_cache = re.compile(
|
||||
_host_regex,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return _host_regex_cache
|
||||
|
||||
|
||||
class AnyUrl(str):
|
||||
strip_whitespace = True
|
||||
min_length = 1
|
||||
max_length = 2**16
|
||||
allowed_schemes: Optional[Collection[str]] = None
|
||||
tld_required: bool = False
|
||||
user_required: bool = False
|
||||
host_required: bool = True
|
||||
hidden_parts: Set[str] = set()
|
||||
|
||||
__slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')
|
||||
|
||||
@no_type_check
|
||||
def __new__(cls, url: Optional[str], **kwargs) -> object:
|
||||
return str.__new__(cls, cls.build(**kwargs) if url is None else url)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
scheme: str,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
tld: Optional[str] = None,
|
||||
host_type: str = 'domain',
|
||||
port: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
fragment: Optional[str] = None,
|
||||
) -> None:
|
||||
str.__init__(url)
|
||||
self.scheme = scheme
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.tld = tld
|
||||
self.host_type = host_type
|
||||
self.port = port
|
||||
self.path = path
|
||||
self.query = query
|
||||
self.fragment = fragment
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
*,
|
||||
scheme: str,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
host: str,
|
||||
port: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
fragment: Optional[str] = None,
|
||||
**_kwargs: str,
|
||||
) -> str:
|
||||
parts = Parts(
|
||||
scheme=scheme,
|
||||
user=user,
|
||||
password=password,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
query=query,
|
||||
fragment=fragment,
|
||||
**_kwargs, # type: ignore[misc]
|
||||
)
|
||||
|
||||
url = scheme + '://'
|
||||
if user:
|
||||
url += user
|
||||
if password:
|
||||
url += ':' + password
|
||||
if user or password:
|
||||
url += '@'
|
||||
url += host
|
||||
if port and ('port' not in cls.hidden_parts or cls.get_default_parts(parts).get('port') != port):
|
||||
url += ':' + port
|
||||
if path:
|
||||
url += path
|
||||
if query:
|
||||
url += '?' + query
|
||||
if fragment:
|
||||
url += '#' + fragment
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl':
|
||||
if value.__class__ == cls:
|
||||
return value
|
||||
value = str_validator(value)
|
||||
if cls.strip_whitespace:
|
||||
value = value.strip()
|
||||
url: str = cast(str, constr_length_validator(value, field, config))
|
||||
|
||||
m = cls._match_url(url)
|
||||
# the regex should always match, if it doesn't please report with details of the URL tried
|
||||
assert m, 'URL regex failed unexpectedly'
|
||||
|
||||
original_parts = cast('Parts', m.groupdict())
|
||||
parts = cls.apply_default_parts(original_parts)
|
||||
parts = cls.validate_parts(parts)
|
||||
|
||||
if m.end() != len(url):
|
||||
raise errors.UrlExtraError(extra=url[m.end() :])
|
||||
|
||||
return cls._build_url(m, url, parts)
|
||||
|
||||
@classmethod
|
||||
def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'AnyUrl':
|
||||
"""
|
||||
Validate hosts and build the AnyUrl object. Split from `validate` so this method
|
||||
can be altered in `MultiHostDsn`.
|
||||
"""
|
||||
host, tld, host_type, rebuild = cls.validate_host(parts)
|
||||
|
||||
return cls(
|
||||
None if rebuild else url,
|
||||
scheme=parts['scheme'],
|
||||
user=parts['user'],
|
||||
password=parts['password'],
|
||||
host=host,
|
||||
tld=tld,
|
||||
host_type=host_type,
|
||||
port=parts['port'],
|
||||
path=parts['path'],
|
||||
query=parts['query'],
|
||||
fragment=parts['fragment'],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _match_url(url: str) -> Optional[Match[str]]:
|
||||
return url_regex().match(url)
|
||||
|
||||
@staticmethod
|
||||
def _validate_port(port: Optional[str]) -> None:
|
||||
if port is not None and int(port) > 65_535:
|
||||
raise errors.UrlPortError()
|
||||
|
||||
@classmethod
|
||||
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
|
||||
"""
|
||||
A method used to validate parts of a URL.
|
||||
Could be overridden to set default values for parts if missing
|
||||
"""
|
||||
scheme = parts['scheme']
|
||||
if scheme is None:
|
||||
raise errors.UrlSchemeError()
|
||||
|
||||
if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
|
||||
raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))
|
||||
|
||||
if validate_port:
|
||||
cls._validate_port(parts['port'])
|
||||
|
||||
user = parts['user']
|
||||
if cls.user_required and user is None:
|
||||
raise errors.UrlUserInfoError()
|
||||
|
||||
return parts
|
||||
|
||||
@classmethod
|
||||
def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]:
|
||||
tld, host_type, rebuild = None, None, False
|
||||
for f in ('domain', 'ipv4', 'ipv6'):
|
||||
host = parts[f] # type: ignore[literal-required]
|
||||
if host:
|
||||
host_type = f
|
||||
break
|
||||
|
||||
if host is None:
|
||||
if cls.host_required:
|
||||
raise errors.UrlHostError()
|
||||
elif host_type == 'domain':
|
||||
is_international = False
|
||||
d = ascii_domain_regex().fullmatch(host)
|
||||
if d is None:
|
||||
d = int_domain_regex().fullmatch(host)
|
||||
if d is None:
|
||||
raise errors.UrlHostError()
|
||||
is_international = True
|
||||
|
||||
tld = d.group('tld')
|
||||
if tld is None and not is_international:
|
||||
d = int_domain_regex().fullmatch(host)
|
||||
assert d is not None
|
||||
tld = d.group('tld')
|
||||
is_international = True
|
||||
|
||||
if tld is not None:
|
||||
tld = tld[1:]
|
||||
elif cls.tld_required:
|
||||
raise errors.UrlHostTldError()
|
||||
|
||||
if is_international:
|
||||
host_type = 'int_domain'
|
||||
rebuild = True
|
||||
host = host.encode('idna').decode('ascii')
|
||||
if tld is not None:
|
||||
tld = tld.encode('idna').decode('ascii')
|
||||
|
||||
return host, tld, host_type, rebuild # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def apply_default_parts(cls, parts: 'Parts') -> 'Parts':
|
||||
for key, value in cls.get_default_parts(parts).items():
|
||||
if not parts[key]: # type: ignore[literal-required]
|
||||
parts[key] = value # type: ignore[literal-required]
|
||||
return parts
|
||||
|
||||
def __repr__(self) -> str:
|
||||
extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None)
|
||||
return f'{self.__class__.__name__}({super().__repr__()}, {extra})'
|
||||
|
||||
|
||||
class AnyHttpUrl(AnyUrl):
|
||||
allowed_schemes = {'http', 'https'}
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class HttpUrl(AnyHttpUrl):
|
||||
tld_required = True
|
||||
# https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers
|
||||
max_length = 2083
|
||||
hidden_parts = {'port'}
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {'port': '80' if parts['scheme'] == 'http' else '443'}
|
||||
|
||||
|
||||
class FileUrl(AnyUrl):
|
||||
allowed_schemes = {'file'}
|
||||
host_required = False
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class MultiHostDsn(AnyUrl):
|
||||
__slots__ = AnyUrl.__slots__ + ('hosts',)
|
||||
|
||||
def __init__(self, *args: Any, hosts: Optional[List['HostParts']] = None, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hosts = hosts
|
||||
|
||||
@staticmethod
|
||||
def _match_url(url: str) -> Optional[Match[str]]:
|
||||
return multi_host_url_regex().match(url)
|
||||
|
||||
@classmethod
|
||||
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
|
||||
return super().validate_parts(parts, validate_port=False)
|
||||
|
||||
@classmethod
|
||||
def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'MultiHostDsn':
|
||||
hosts_parts: List['HostParts'] = []
|
||||
host_re = host_regex()
|
||||
for host in m.groupdict()['hosts'].split(','):
|
||||
d: Parts = host_re.match(host).groupdict() # type: ignore
|
||||
host, tld, host_type, rebuild = cls.validate_host(d)
|
||||
port = d.get('port')
|
||||
cls._validate_port(port)
|
||||
hosts_parts.append(
|
||||
{
|
||||
'host': host,
|
||||
'host_type': host_type,
|
||||
'tld': tld,
|
||||
'rebuild': rebuild,
|
||||
'port': port,
|
||||
}
|
||||
)
|
||||
|
||||
if len(hosts_parts) > 1:
|
||||
return cls(
|
||||
None if any([hp['rebuild'] for hp in hosts_parts]) else url,
|
||||
scheme=parts['scheme'],
|
||||
user=parts['user'],
|
||||
password=parts['password'],
|
||||
path=parts['path'],
|
||||
query=parts['query'],
|
||||
fragment=parts['fragment'],
|
||||
host_type=None,
|
||||
hosts=hosts_parts,
|
||||
)
|
||||
else:
|
||||
# backwards compatibility with single host
|
||||
host_part = hosts_parts[0]
|
||||
return cls(
|
||||
None if host_part['rebuild'] else url,
|
||||
scheme=parts['scheme'],
|
||||
user=parts['user'],
|
||||
password=parts['password'],
|
||||
host=host_part['host'],
|
||||
tld=host_part['tld'],
|
||||
host_type=host_part['host_type'],
|
||||
port=host_part.get('port'),
|
||||
path=parts['path'],
|
||||
query=parts['query'],
|
||||
fragment=parts['fragment'],
|
||||
)
|
||||
|
||||
|
||||
class PostgresDsn(MultiHostDsn):
|
||||
allowed_schemes = {
|
||||
'postgres',
|
||||
'postgresql',
|
||||
'postgresql+asyncpg',
|
||||
'postgresql+pg8000',
|
||||
'postgresql+psycopg',
|
||||
'postgresql+psycopg2',
|
||||
'postgresql+psycopg2cffi',
|
||||
'postgresql+py-postgresql',
|
||||
'postgresql+pygresql',
|
||||
}
|
||||
user_required = True
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class CockroachDsn(AnyUrl):
|
||||
allowed_schemes = {
|
||||
'cockroachdb',
|
||||
'cockroachdb+psycopg2',
|
||||
'cockroachdb+asyncpg',
|
||||
}
|
||||
user_required = True
|
||||
|
||||
|
||||
class AmqpDsn(AnyUrl):
|
||||
allowed_schemes = {'amqp', 'amqps'}
|
||||
host_required = False
|
||||
|
||||
|
||||
class RedisDsn(AnyUrl):
|
||||
__slots__ = ()
|
||||
allowed_schemes = {'redis', 'rediss'}
|
||||
host_required = False
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {
|
||||
'domain': 'localhost' if not (parts['ipv4'] or parts['ipv6']) else '',
|
||||
'port': '6379',
|
||||
'path': '/0',
|
||||
}
|
||||
|
||||
|
||||
class MongoDsn(AnyUrl):
|
||||
allowed_schemes = {'mongodb'}
|
||||
|
||||
# TODO: Needed to generic "Parts" for "Replica Set", "Sharded Cluster", and other mongodb deployment modes
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {
|
||||
'port': '27017',
|
||||
}
|
||||
|
||||
|
||||
class KafkaDsn(AnyUrl):
|
||||
allowed_schemes = {'kafka'}
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {
|
||||
'domain': 'localhost',
|
||||
'port': '9092',
|
||||
}
|
||||
|
||||
|
||||
def stricturl(
|
||||
*,
|
||||
strip_whitespace: bool = True,
|
||||
min_length: int = 1,
|
||||
max_length: int = 2**16,
|
||||
tld_required: bool = True,
|
||||
host_required: bool = True,
|
||||
allowed_schemes: Optional[Collection[str]] = None,
|
||||
) -> Type[AnyUrl]:
|
||||
# use kwargs then define conf in a dict to aid with IDE type hinting
|
||||
namespace = dict(
|
||||
strip_whitespace=strip_whitespace,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
tld_required=tld_required,
|
||||
host_required=host_required,
|
||||
allowed_schemes=allowed_schemes,
|
||||
)
|
||||
return type('UrlValue', (AnyUrl,), namespace)
|
||||
|
||||
|
||||
def import_email_validator() -> None:
|
||||
global email_validator
|
||||
try:
|
||||
import email_validator
|
||||
except ImportError as e:
|
||||
raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e
|
||||
|
||||
|
||||
class EmailStr(str):
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='email')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
# included here and below so the error happens straight away
|
||||
import_email_validator()
|
||||
|
||||
yield str_validator
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Union[str]) -> str:
|
||||
return validate_email(value)[1]
|
||||
|
||||
|
||||
class NameEmail(Representation):
|
||||
__slots__ = 'name', 'email'
|
||||
|
||||
def __init__(self, name: str, email: str):
|
||||
self.name = name
|
||||
self.email = email
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email)
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='name-email')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
import_email_validator()
|
||||
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any) -> 'NameEmail':
|
||||
if value.__class__ == cls:
|
||||
return value
|
||||
value = str_validator(value)
|
||||
return cls(*validate_email(value))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name} <{self.email}>'
|
||||
|
||||
|
||||
class IPvAnyAddress(_BaseAddress):
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='ipvanyaddress')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]:
|
||||
try:
|
||||
return IPv4Address(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Address(value)
|
||||
except ValueError:
|
||||
raise errors.IPvAnyAddressError()
|
||||
|
||||
|
||||
class IPvAnyInterface(_BaseAddress):
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='ipvanyinterface')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]:
|
||||
try:
|
||||
return IPv4Interface(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Interface(value)
|
||||
except ValueError:
|
||||
raise errors.IPvAnyInterfaceError()
|
||||
|
||||
|
||||
class IPvAnyNetwork(_BaseNetwork): # type: ignore
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='ipvanynetwork')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]:
|
||||
# Assume IP Network is defined with a default value for ``strict`` argument.
|
||||
# Define your own class if you want to specify network address check strictness.
|
||||
try:
|
||||
return IPv4Network(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Network(value)
|
||||
except ValueError:
|
||||
raise errors.IPvAnyNetworkError()
|
||||
|
||||
|
||||
pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *')
|
||||
MAX_EMAIL_LENGTH = 2048
|
||||
"""Maximum length for an email.
|
||||
A somewhat arbitrary but very generous number compared to what is allowed by most implementations.
|
||||
"""
|
||||
|
||||
|
||||
def validate_email(value: Union[str]) -> Tuple[str, str]:
|
||||
"""
|
||||
Email address validation using https://pypi.org/project/email-validator/
|
||||
Notes:
|
||||
* raw ip address (literal) domain parts are not allowed.
|
||||
* "John Doe <local_part@domain.com>" style "pretty" email addresses are processed
|
||||
* spaces are striped from the beginning and end of addresses but no error is raised
|
||||
"""
|
||||
if email_validator is None:
|
||||
import_email_validator()
|
||||
|
||||
if len(value) > MAX_EMAIL_LENGTH:
|
||||
raise errors.EmailError()
|
||||
|
||||
m = pretty_email_regex.fullmatch(value)
|
||||
name: Union[str, None] = None
|
||||
if m:
|
||||
name, value = m.groups()
|
||||
email = value.strip()
|
||||
try:
|
||||
parts = email_validator.validate_email(email, check_deliverability=False)
|
||||
except email_validator.EmailNotValidError as e:
|
||||
raise errors.EmailError from e
|
||||
|
||||
if hasattr(parts, 'normalized'):
|
||||
# email-validator >= 2
|
||||
email = parts.normalized
|
||||
assert email is not None
|
||||
name = name or parts.local_part
|
||||
return name, email
|
||||
else:
|
||||
# email-validator >1, <2
|
||||
at_index = email.index('@')
|
||||
local_part = email[:at_index] # RFC 5321, local part must be case-sensitive.
|
||||
global_part = email[at_index:].lower()
|
||||
|
||||
return name or local_part, local_part + global_part
|
||||
66
venv/lib/python3.11/site-packages/pydantic/v1/parse.py
Normal file
66
venv/lib/python3.11/site-packages/pydantic/v1/parse.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import json
|
||||
import pickle
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic.v1.types import StrBytes
|
||||
|
||||
|
||||
class Protocol(str, Enum):
|
||||
json = 'json'
|
||||
pickle = 'pickle'
|
||||
|
||||
|
||||
def load_str_bytes(
|
||||
b: StrBytes,
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
if proto is None and content_type:
|
||||
if content_type.endswith(('json', 'javascript')):
|
||||
pass
|
||||
elif allow_pickle and content_type.endswith('pickle'):
|
||||
proto = Protocol.pickle
|
||||
else:
|
||||
raise TypeError(f'Unknown content-type: {content_type}')
|
||||
|
||||
proto = proto or Protocol.json
|
||||
|
||||
if proto == Protocol.json:
|
||||
if isinstance(b, bytes):
|
||||
b = b.decode(encoding)
|
||||
return json_loads(b)
|
||||
elif proto == Protocol.pickle:
|
||||
if not allow_pickle:
|
||||
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
|
||||
bb = b if isinstance(b, bytes) else b.encode()
|
||||
return pickle.loads(bb)
|
||||
else:
|
||||
raise TypeError(f'Unknown protocol: {proto}')
|
||||
|
||||
|
||||
def load_file(
|
||||
path: Union[str, Path],
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
path = Path(path)
|
||||
b = path.read_bytes()
|
||||
if content_type is None:
|
||||
if path.suffix in ('.js', '.json'):
|
||||
proto = Protocol.json
|
||||
elif path.suffix == '.pkl':
|
||||
proto = Protocol.pickle
|
||||
|
||||
return load_str_bytes(
|
||||
b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
|
||||
)
|
||||
1163
venv/lib/python3.11/site-packages/pydantic/v1/schema.py
Normal file
1163
venv/lib/python3.11/site-packages/pydantic/v1/schema.py
Normal file
File diff suppressed because it is too large
Load Diff
92
venv/lib/python3.11/site-packages/pydantic/v1/tools.py
Normal file
92
venv/lib/python3.11/site-packages/pydantic/v1/tools.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic.v1.parse import Protocol, load_file, load_str_bytes
|
||||
from pydantic.v1.types import StrBytes
|
||||
from pydantic.v1.typing import display_as_type
|
||||
|
||||
__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of')
|
||||
|
||||
NameFactory = Union[str, Callable[[Type[Any]], str]]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import DictStrAny
|
||||
|
||||
|
||||
def _generate_parsing_type_name(type_: Any) -> str:
|
||||
return f'ParsingModel[{display_as_type(type_)}]'
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any:
|
||||
from pydantic.v1.main import create_model
|
||||
|
||||
if type_name is None:
|
||||
type_name = _generate_parsing_type_name
|
||||
if not isinstance(type_name, str):
|
||||
type_name = type_name(type_)
|
||||
return create_model(type_name, __root__=(type_, ...))
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T:
|
||||
model_type = _get_parsing_type(type_, type_name=type_name) # type: ignore[arg-type]
|
||||
return model_type(__root__=obj).__root__
|
||||
|
||||
|
||||
def parse_file_as(
|
||||
type_: Type[T],
|
||||
path: Union[str, Path],
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
type_name: Optional[NameFactory] = None,
|
||||
) -> T:
|
||||
obj = load_file(
|
||||
path,
|
||||
proto=proto,
|
||||
content_type=content_type,
|
||||
encoding=encoding,
|
||||
allow_pickle=allow_pickle,
|
||||
json_loads=json_loads,
|
||||
)
|
||||
return parse_obj_as(type_, obj, type_name=type_name)
|
||||
|
||||
|
||||
def parse_raw_as(
|
||||
type_: Type[T],
|
||||
b: StrBytes,
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
type_name: Optional[NameFactory] = None,
|
||||
) -> T:
|
||||
obj = load_str_bytes(
|
||||
b,
|
||||
proto=proto,
|
||||
content_type=content_type,
|
||||
encoding=encoding,
|
||||
allow_pickle=allow_pickle,
|
||||
json_loads=json_loads,
|
||||
)
|
||||
return parse_obj_as(type_, obj, type_name=type_name)
|
||||
|
||||
|
||||
def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny':
|
||||
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one"""
|
||||
return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs)
|
||||
|
||||
|
||||
def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_json_kwargs: Any) -> str:
|
||||
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one"""
|
||||
return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs)
|
||||
1205
venv/lib/python3.11/site-packages/pydantic/v1/types.py
Normal file
1205
venv/lib/python3.11/site-packages/pydantic/v1/types.py
Normal file
File diff suppressed because it is too large
Load Diff
608
venv/lib/python3.11/site-packages/pydantic/v1/typing.py
Normal file
608
venv/lib/python3.11/site-packages/pydantic/v1/typing.py
Normal file
@@ -0,0 +1,608 @@
|
||||
import sys
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from os import PathLike
|
||||
from typing import ( # type: ignore
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable as TypingCallable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
NewType,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
_eval_type,
|
||||
cast,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import (
|
||||
Annotated,
|
||||
Final,
|
||||
Literal,
|
||||
NotRequired as TypedDictNotRequired,
|
||||
Required as TypedDictRequired,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import _TypingBase as typing_base # type: ignore
|
||||
except ImportError:
|
||||
from typing import _Final as typing_base # type: ignore
|
||||
|
||||
try:
|
||||
from typing import GenericAlias as TypingGenericAlias # type: ignore
|
||||
except ImportError:
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
TypingGenericAlias = ()
|
||||
|
||||
try:
|
||||
from types import UnionType as TypesUnionType # type: ignore
|
||||
except ImportError:
|
||||
# python < 3.10 does not have UnionType (str | int, byte | bool and so on)
|
||||
TypesUnionType = ()
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
return type_._evaluate(globalns, localns)
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Even though it is the right signature for python 3.9, mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
|
||||
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
|
||||
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
|
||||
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# Ensure we always get all the whole `Annotated` hint, not just the annotated type.
|
||||
# For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
|
||||
# so it already returns the full annotation
|
||||
get_all_type_hints = get_type_hints
|
||||
|
||||
else:
|
||||
|
||||
def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
|
||||
return get_type_hints(obj, globalns, localns, include_extras=True)
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
AnyCallable = TypingCallable[..., Any]
|
||||
NoArgAnyCallable = TypingCallable[[], Any]
|
||||
|
||||
# workaround for https://github.com/python/mypy/issues/9496
|
||||
AnyArgTCallable = TypingCallable[..., _T]
|
||||
|
||||
|
||||
# Annotated[...] is implemented by returning an instance of one of these classes, depending on
|
||||
# python/typing_extensions version.
|
||||
AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'}
|
||||
|
||||
|
||||
LITERAL_TYPES: Set[Any] = {Literal}
|
||||
if hasattr(typing, 'Literal'):
|
||||
LITERAL_TYPES.add(typing.Literal)
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
|
||||
def get_origin(t: Type[Any]) -> Optional[Type[Any]]:
|
||||
if type(t).__name__ in AnnotatedTypeNames:
|
||||
# weirdly this is a runtime requirement, as well as for mypy
|
||||
return cast(Type[Any], Annotated)
|
||||
return getattr(t, '__origin__', None)
|
||||
|
||||
else:
|
||||
from typing import get_origin as _typing_get_origin
|
||||
|
||||
def get_origin(tp: Type[Any]) -> Optional[Type[Any]]:
|
||||
"""
|
||||
We can't directly use `typing.get_origin` since we need a fallback to support
|
||||
custom generic classes like `ConstrainedList`
|
||||
It should be useless once https://github.com/cython/cython/issues/3537 is
|
||||
solved and https://github.com/pydantic/pydantic/pull/1753 is merged.
|
||||
"""
|
||||
if type(tp).__name__ in AnnotatedTypeNames:
|
||||
return cast(Type[Any], Annotated) # mypy complains about _SpecialForm
|
||||
return _typing_get_origin(tp) or getattr(tp, '__origin__', None)
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing import _GenericAlias
|
||||
|
||||
def get_args(t: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""Compatibility version of get_args for python 3.7.
|
||||
|
||||
Mostly compatible with the python 3.8 `typing` module version
|
||||
and able to handle almost all use cases.
|
||||
"""
|
||||
if type(t).__name__ in AnnotatedTypeNames:
|
||||
return t.__args__ + t.__metadata__
|
||||
if isinstance(t, _GenericAlias):
|
||||
res = t.__args__
|
||||
if t.__origin__ is Callable and res and res[0] is not Ellipsis:
|
||||
res = (list(res[:-1]), res[-1])
|
||||
return res
|
||||
return getattr(t, '__args__', ())
|
||||
|
||||
else:
|
||||
from typing import get_args as _typing_get_args
|
||||
|
||||
def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""
|
||||
In python 3.9, `typing.Dict`, `typing.List`, ...
|
||||
do have an empty `__args__` by default (instead of the generic ~T for example).
|
||||
In order to still support `Dict` for example and consider it as `Dict[Any, Any]`,
|
||||
we retrieve the `_nparams` value that tells us how many parameters it needs.
|
||||
"""
|
||||
if hasattr(tp, '_nparams'):
|
||||
return (Any,) * tp._nparams
|
||||
# Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple`
|
||||
# in python 3.10- but now returns () for `tuple` and `Tuple`.
|
||||
# This will probably be clarified in pydantic v2
|
||||
try:
|
||||
if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]: # type: ignore[misc]
|
||||
return ((),)
|
||||
# there is a TypeError when compiled with cython
|
||||
except TypeError: # pragma: no cover
|
||||
pass
|
||||
return ()
|
||||
|
||||
def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""Get type arguments with all substitutions performed.
|
||||
|
||||
For unions, basic simplifications used by Union constructor are performed.
|
||||
Examples::
|
||||
get_args(Dict[str, int]) == (str, int)
|
||||
get_args(int) == ()
|
||||
get_args(Union[int, Union[T, int], str][int]) == (int, str)
|
||||
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
|
||||
get_args(Callable[[], T][int]) == ([], int)
|
||||
"""
|
||||
if type(tp).__name__ in AnnotatedTypeNames:
|
||||
return tp.__args__ + tp.__metadata__
|
||||
# the fallback is needed for the same reasons as `get_origin` (see above)
|
||||
return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp)
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
|
||||
def convert_generics(tp: Type[Any]) -> Type[Any]:
|
||||
"""Python 3.9 and older only supports generics from `typing` module.
|
||||
They convert strings to ForwardRef automatically.
|
||||
|
||||
Examples::
|
||||
typing.List['Hero'] == typing.List[ForwardRef('Hero')]
|
||||
"""
|
||||
return tp
|
||||
|
||||
else:
|
||||
from typing import _UnionGenericAlias # type: ignore
|
||||
|
||||
from typing_extensions import _AnnotatedAlias
|
||||
|
||||
def convert_generics(tp: Type[Any]) -> Type[Any]:
|
||||
"""
|
||||
Recursively searches for `str` type hints and replaces them with ForwardRef.
|
||||
|
||||
Examples::
|
||||
convert_generics(list['Hero']) == list[ForwardRef('Hero')]
|
||||
convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')]
|
||||
convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')]
|
||||
convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
if not origin or not hasattr(tp, '__args__'):
|
||||
return tp
|
||||
|
||||
args = get_args(tp)
|
||||
|
||||
# typing.Annotated needs special treatment
|
||||
if origin is Annotated:
|
||||
return _AnnotatedAlias(convert_generics(args[0]), args[1:])
|
||||
|
||||
# recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)`
|
||||
converted = tuple(
|
||||
ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg)
|
||||
for arg in args
|
||||
)
|
||||
|
||||
if converted == args:
|
||||
return tp
|
||||
elif isinstance(tp, TypingGenericAlias):
|
||||
return TypingGenericAlias(origin, converted)
|
||||
elif isinstance(tp, TypesUnionType):
|
||||
# recreate types.UnionType (PEP604, Python >= 3.10)
|
||||
return _UnionGenericAlias(origin, converted)
|
||||
else:
|
||||
try:
|
||||
setattr(tp, '__args__', converted)
|
||||
except AttributeError:
|
||||
pass
|
||||
return tp
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
|
||||
def is_union(tp: Optional[Type[Any]]) -> bool:
|
||||
return tp is Union
|
||||
|
||||
WithArgsTypes = (TypingGenericAlias,)
|
||||
|
||||
else:
|
||||
import types
|
||||
import typing
|
||||
|
||||
def is_union(tp: Optional[Type[Any]]) -> bool:
|
||||
return tp is Union or tp is types.UnionType # noqa: E721
|
||||
|
||||
WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)
|
||||
|
||||
|
||||
StrPath = Union[str, PathLike]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.fields import ModelField
|
||||
|
||||
TupleGenerator = Generator[Tuple[str, Any], None, None]
|
||||
DictStrAny = Dict[str, Any]
|
||||
DictAny = Dict[Any, Any]
|
||||
SetStr = Set[str]
|
||||
ListStr = List[str]
|
||||
IntStr = Union[int, str]
|
||||
AbstractSetIntStr = AbstractSet[IntStr]
|
||||
DictIntStrAny = Dict[IntStr, Any]
|
||||
MappingIntStrAny = Mapping[IntStr, Any]
|
||||
CallableGenerator = Generator[AnyCallable, None, None]
|
||||
ReprArgs = Sequence[Tuple[Optional[str], Any]]
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
AnyClassMethod = classmethod[Any]
|
||||
else:
|
||||
# classmethod[TargetType, CallableParamSpecType, CallableReturnType]
|
||||
AnyClassMethod = classmethod[Any, Any, Any]
|
||||
|
||||
__all__ = (
|
||||
'AnyCallable',
|
||||
'NoArgAnyCallable',
|
||||
'NoneType',
|
||||
'is_none_type',
|
||||
'display_as_type',
|
||||
'resolve_annotations',
|
||||
'is_callable_type',
|
||||
'is_literal_type',
|
||||
'all_literal_values',
|
||||
'is_namedtuple',
|
||||
'is_typeddict',
|
||||
'is_typeddict_special',
|
||||
'is_new_type',
|
||||
'new_type_supertype',
|
||||
'is_classvar',
|
||||
'is_finalvar',
|
||||
'update_field_forward_refs',
|
||||
'update_model_forward_refs',
|
||||
'TupleGenerator',
|
||||
'DictStrAny',
|
||||
'DictAny',
|
||||
'SetStr',
|
||||
'ListStr',
|
||||
'IntStr',
|
||||
'AbstractSetIntStr',
|
||||
'DictIntStrAny',
|
||||
'CallableGenerator',
|
||||
'ReprArgs',
|
||||
'AnyClassMethod',
|
||||
'CallableGenerator',
|
||||
'WithArgsTypes',
|
||||
'get_args',
|
||||
'get_origin',
|
||||
'get_sub_types',
|
||||
'typing_base',
|
||||
'get_all_type_hints',
|
||||
'is_union',
|
||||
'StrPath',
|
||||
'MappingIntStrAny',
|
||||
)
|
||||
|
||||
|
||||
NoneType = None.__class__
|
||||
|
||||
|
||||
NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None])
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
# Even though this implementation is slower, we need it for python 3.7:
|
||||
# In python 3.7 "Literal" is not a builtin type and uses a different
|
||||
# mechanism.
|
||||
# for this reason `Literal[None] is Literal[None]` evaluates to `False`,
|
||||
# breaking the faster implementation used for the other python versions.
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
return type_ in NONE_TYPES
|
||||
|
||||
elif sys.version_info[:2] == (3, 8):
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
for none_type in NONE_TYPES:
|
||||
if type_ is none_type:
|
||||
return True
|
||||
# With python 3.8, specifically 3.8.10, Literal "is" check sare very flakey
|
||||
# can change on very subtle changes like use of types in other modules,
|
||||
# hopefully this check avoids that issue.
|
||||
if is_literal_type(type_): # pragma: no cover
|
||||
return all_literal_values(type_) == (None,)
|
||||
return False
|
||||
|
||||
else:
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
return type_ in NONE_TYPES
|
||||
|
||||
|
||||
def display_as_type(v: Type[Any]) -> str:
|
||||
if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type):
|
||||
v = v.__class__
|
||||
|
||||
if is_union(get_origin(v)):
|
||||
return f'Union[{", ".join(map(display_as_type, get_args(v)))}]'
|
||||
|
||||
if isinstance(v, WithArgsTypes):
|
||||
# Generic alias are constructs like `list[int]`
|
||||
return str(v).replace('typing.', '')
|
||||
|
||||
try:
|
||||
return v.__name__
|
||||
except AttributeError:
|
||||
# happens with typing objects
|
||||
return str(v).replace('typing.', '')
|
||||
|
||||
|
||||
def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]:
|
||||
"""
|
||||
Partially taken from typing.get_type_hints.
|
||||
|
||||
Resolve string or ForwardRef annotations into type objects if possible.
|
||||
"""
|
||||
base_globals: Optional[Dict[str, Any]] = None
|
||||
if module_name:
|
||||
try:
|
||||
module = sys.modules[module_name]
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
pass
|
||||
else:
|
||||
base_globals = module.__dict__
|
||||
|
||||
annotations = {}
|
||||
for name, value in raw_annotations.items():
|
||||
if isinstance(value, str):
|
||||
if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
|
||||
value = ForwardRef(value, is_argument=False, is_class=True)
|
||||
else:
|
||||
value = ForwardRef(value, is_argument=False)
|
||||
try:
|
||||
if sys.version_info >= (3, 13):
|
||||
value = _eval_type(value, base_globals, None, type_params=())
|
||||
else:
|
||||
value = _eval_type(value, base_globals, None)
|
||||
except NameError:
|
||||
# this is ok, it can be fixed with update_forward_refs
|
||||
pass
|
||||
annotations[name] = value
|
||||
return annotations
|
||||
|
||||
|
||||
def is_callable_type(type_: Type[Any]) -> bool:
|
||||
return type_ is Callable or get_origin(type_) is Callable
|
||||
|
||||
|
||||
def is_literal_type(type_: Type[Any]) -> bool:
|
||||
return Literal is not None and get_origin(type_) in LITERAL_TYPES
|
||||
|
||||
|
||||
def literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
|
||||
return get_args(type_)
|
||||
|
||||
|
||||
def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""
|
||||
This method is used to retrieve all Literal values as
|
||||
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
|
||||
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
|
||||
"""
|
||||
if not is_literal_type(type_):
|
||||
return (type_,)
|
||||
|
||||
values = literal_values(type_)
|
||||
return tuple(x for value in values for x in all_literal_values(value))
|
||||
|
||||
|
||||
def is_namedtuple(type_: Type[Any]) -> bool:
|
||||
"""
|
||||
Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`
|
||||
"""
|
||||
from pydantic.v1.utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
|
||||
|
||||
|
||||
def is_typeddict(type_: Type[Any]) -> bool:
|
||||
"""
|
||||
Check if a given class is a typed dict (from `typing` or `typing_extensions`)
|
||||
In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict)
|
||||
"""
|
||||
from pydantic.v1.utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, dict) and hasattr(type_, '__total__')
|
||||
|
||||
|
||||
def _check_typeddict_special(type_: Any) -> bool:
|
||||
return type_ is TypedDictRequired or type_ is TypedDictNotRequired
|
||||
|
||||
|
||||
def is_typeddict_special(type_: Any) -> bool:
|
||||
"""
|
||||
Check if type is a TypedDict special form (Required or NotRequired).
|
||||
"""
|
||||
return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_))
|
||||
|
||||
|
||||
test_type = NewType('test_type', str)
|
||||
|
||||
|
||||
def is_new_type(type_: Type[Any]) -> bool:
|
||||
"""
|
||||
Check whether type_ was created using typing.NewType
|
||||
"""
|
||||
return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__') # type: ignore
|
||||
|
||||
|
||||
def new_type_supertype(type_: Type[Any]) -> Type[Any]:
|
||||
while hasattr(type_, '__supertype__'):
|
||||
type_ = type_.__supertype__
|
||||
return type_
|
||||
|
||||
|
||||
def _check_classvar(v: Optional[Type[Any]]) -> bool:
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
|
||||
|
||||
|
||||
def _check_finalvar(v: Optional[Type[Any]]) -> bool:
|
||||
"""
|
||||
Check if a given type is a `typing.Final` type.
|
||||
"""
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
|
||||
|
||||
|
||||
def is_classvar(ann_type: Type[Any]) -> bool:
|
||||
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
|
||||
return True
|
||||
|
||||
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
|
||||
# forward references, see #3679
|
||||
if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_finalvar(ann_type: Type[Any]) -> bool:
|
||||
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
|
||||
|
||||
|
||||
def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None:
|
||||
"""
|
||||
Try to update ForwardRefs on fields based on this ModelField, globalns and localns.
|
||||
"""
|
||||
prepare = False
|
||||
if field.type_.__class__ == ForwardRef:
|
||||
prepare = True
|
||||
field.type_ = evaluate_forwardref(field.type_, globalns, localns or None)
|
||||
if field.outer_type_.__class__ == ForwardRef:
|
||||
prepare = True
|
||||
field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None)
|
||||
if prepare:
|
||||
field.prepare()
|
||||
|
||||
if field.sub_fields:
|
||||
for sub_f in field.sub_fields:
|
||||
update_field_forward_refs(sub_f, globalns=globalns, localns=localns)
|
||||
|
||||
if field.discriminator_key is not None:
|
||||
field.prepare_discriminated_union_sub_fields()
|
||||
|
||||
|
||||
def update_model_forward_refs(
|
||||
model: Type[Any],
|
||||
fields: Iterable['ModelField'],
|
||||
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable],
|
||||
localns: 'DictStrAny',
|
||||
exc_to_suppress: Tuple[Type[BaseException], ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Try to update model fields ForwardRefs based on model and localns.
|
||||
"""
|
||||
if model.__module__ in sys.modules:
|
||||
globalns = sys.modules[model.__module__].__dict__.copy()
|
||||
else:
|
||||
globalns = {}
|
||||
|
||||
globalns.setdefault(model.__name__, model)
|
||||
|
||||
for f in fields:
|
||||
try:
|
||||
update_field_forward_refs(f, globalns=globalns, localns=localns)
|
||||
except exc_to_suppress:
|
||||
pass
|
||||
|
||||
for key in set(json_encoders.keys()):
|
||||
if isinstance(key, str):
|
||||
fr: ForwardRef = ForwardRef(key)
|
||||
elif isinstance(key, ForwardRef):
|
||||
fr = key
|
||||
else:
|
||||
continue
|
||||
|
||||
try:
|
||||
new_key = evaluate_forwardref(fr, globalns, localns or None)
|
||||
except exc_to_suppress: # pragma: no cover
|
||||
continue
|
||||
|
||||
json_encoders[new_key] = json_encoders.pop(key)
|
||||
|
||||
|
||||
def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]:
|
||||
"""
|
||||
Tries to get the class of a Type[T] annotation. Returns True if Type is used
|
||||
without brackets. Otherwise returns None.
|
||||
"""
|
||||
if type_ is type:
|
||||
return True
|
||||
|
||||
if get_origin(type_) is None:
|
||||
return None
|
||||
|
||||
args = get_args(type_)
|
||||
if not args or not isinstance(args[0], type):
|
||||
return True
|
||||
else:
|
||||
return args[0]
|
||||
|
||||
|
||||
def get_sub_types(tp: Any) -> List[Any]:
|
||||
"""
|
||||
Return all the types that are allowed by type `tp`
|
||||
`tp` can be a `Union` of allowed types or an `Annotated` type
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
if origin is Annotated:
|
||||
return get_sub_types(get_args(tp)[0])
|
||||
elif is_union(origin):
|
||||
return [x for t in get_args(tp) for x in get_sub_types(t)]
|
||||
else:
|
||||
return [tp]
|
||||
804
venv/lib/python3.11/site-packages/pydantic/v1/utils.py
Normal file
804
venv/lib/python3.11/site-packages/pydantic/v1/utils.py
Normal file
@@ -0,0 +1,804 @@
|
||||
import keyword
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from itertools import islice, zip_longest
|
||||
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.typing import (
|
||||
NoneType,
|
||||
WithArgsTypes,
|
||||
all_literal_values,
|
||||
display_as_type,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
)
|
||||
from pydantic.v1.version import version_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.dataclasses import Dataclass
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
|
||||
|
||||
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
|
||||
|
||||
__all__ = (
|
||||
'import_string',
|
||||
'sequence_like',
|
||||
'validate_field_name',
|
||||
'lenient_isinstance',
|
||||
'lenient_issubclass',
|
||||
'in_ipython',
|
||||
'is_valid_identifier',
|
||||
'deep_update',
|
||||
'update_not_none',
|
||||
'almost_equal_floats',
|
||||
'get_model',
|
||||
'to_camel',
|
||||
'to_lower_camel',
|
||||
'is_valid_field',
|
||||
'smart_deepcopy',
|
||||
'PyObjectStr',
|
||||
'Representation',
|
||||
'GetterDict',
|
||||
'ValueItems',
|
||||
'version_info', # required here to match behaviour in v1.3
|
||||
'ClassAttribute',
|
||||
'path_type',
|
||||
'ROOT_KEY',
|
||||
'get_unique_discriminator_alias',
|
||||
'get_discriminator_alias_and_values',
|
||||
'DUNDER_ATTRIBUTES',
|
||||
)
|
||||
|
||||
ROOT_KEY = '__root__'
|
||||
# these are types that are returned unchanged by deepcopy
|
||||
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
|
||||
int,
|
||||
float,
|
||||
complex,
|
||||
str,
|
||||
bool,
|
||||
bytes,
|
||||
type,
|
||||
NoneType,
|
||||
FunctionType,
|
||||
BuiltinFunctionType,
|
||||
LambdaType,
|
||||
weakref.ref,
|
||||
CodeType,
|
||||
# note: including ModuleType will differ from behaviour of deepcopy by not producing error.
|
||||
# It might be not a good idea in general, but considering that this function used only internally
|
||||
# against default values of fields, this will allow to actually have a field with module as default value
|
||||
ModuleType,
|
||||
NotImplemented.__class__,
|
||||
Ellipsis.__class__,
|
||||
}
|
||||
|
||||
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
|
||||
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
|
||||
list,
|
||||
set,
|
||||
tuple,
|
||||
frozenset,
|
||||
dict,
|
||||
OrderedDict,
|
||||
defaultdict,
|
||||
deque,
|
||||
}
|
||||
|
||||
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
"""
|
||||
Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
|
||||
last name in the path. Raise ImportError if the import fails.
|
||||
"""
|
||||
from importlib import import_module
|
||||
|
||||
try:
|
||||
module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
|
||||
except ValueError as e:
|
||||
raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e
|
||||
|
||||
module = import_module(module_path)
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError as e:
|
||||
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
|
||||
|
||||
|
||||
def truncate(v: Union[str], *, max_len: int = 80) -> str:
|
||||
"""
|
||||
Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
|
||||
"""
|
||||
warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
|
||||
if isinstance(v, str) and len(v) > (max_len - 2):
|
||||
# -3 so quote + string + … + quote has correct length
|
||||
return (v[: (max_len - 3)] + '…').__repr__()
|
||||
try:
|
||||
v = v.__repr__()
|
||||
except TypeError:
|
||||
v = v.__class__.__repr__(v) # in case v is a type
|
||||
if len(v) > max_len:
|
||||
v = v[: max_len - 1] + '…'
|
||||
return v
|
||||
|
||||
|
||||
def sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
|
||||
|
||||
|
||||
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
|
||||
"""
|
||||
Ensure that the field's name does not shadow an existing attribute of the model.
|
||||
"""
|
||||
for base in bases:
|
||||
if getattr(base, field_name, None):
|
||||
raise NameError(
|
||||
f'Field name "{field_name}" shadows a BaseModel attribute; '
|
||||
f'use a different field name with "alias=\'{field_name}\'".'
|
||||
)
|
||||
|
||||
|
||||
def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
|
||||
try:
|
||||
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
if isinstance(cls, WithArgsTypes):
|
||||
return False
|
||||
raise # pragma: no cover
|
||||
|
||||
|
||||
def in_ipython() -> bool:
|
||||
"""
|
||||
Check whether we're in an ipython environment, including jupyter notebooks.
|
||||
"""
|
||||
try:
|
||||
eval('__IPYTHON__')
|
||||
except NameError:
|
||||
return False
|
||||
else: # pragma: no cover
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_identifier(identifier: str) -> bool:
|
||||
"""
|
||||
Checks that a string is a valid identifier and not a Python keyword.
|
||||
:param identifier: The identifier to test.
|
||||
:return: True if the identifier is valid.
|
||||
"""
|
||||
return identifier.isidentifier() and not keyword.iskeyword(identifier)
|
||||
|
||||
|
||||
KeyType = TypeVar('KeyType')
|
||||
|
||||
|
||||
def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
|
||||
updated_mapping = mapping.copy()
|
||||
for updating_mapping in updating_mappings:
|
||||
for k, v in updating_mapping.items():
|
||||
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
|
||||
updated_mapping[k] = deep_update(updated_mapping[k], v)
|
||||
else:
|
||||
updated_mapping[k] = v
|
||||
return updated_mapping
|
||||
|
||||
|
||||
def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
|
||||
mapping.update({k: v for k, v in update.items() if v is not None})
|
||||
|
||||
|
||||
def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
|
||||
"""
|
||||
Return True if two floats are almost equal
|
||||
"""
|
||||
return abs(value_1 - value_2) <= delta
|
||||
|
||||
|
||||
def generate_model_signature(
|
||||
init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
|
||||
) -> 'Signature':
|
||||
"""
|
||||
Generate signature for model based on its fields
|
||||
"""
|
||||
from inspect import Parameter, Signature, signature
|
||||
|
||||
from pydantic.v1.config import Extra
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: Dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = config.allow_population_by_field_name
|
||||
for field_name, field in fields.items():
|
||||
param_name = field.alias
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
elif not is_valid_identifier(param_name):
|
||||
if allow_names and is_valid_identifier(field_name):
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
|
||||
# TODO: replace annotation with actual expected types once #1055 solved
|
||||
kwargs = {'default': field.default} if not field.required else {}
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
|
||||
)
|
||||
|
||||
if config.extra is Extra.allow:
|
||||
use_var_kw = True
|
||||
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
else:
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||||
|
||||
|
||||
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
|
||||
from pydantic.v1.main import BaseModel
|
||||
|
||||
try:
|
||||
model_cls = obj.__pydantic_model__ # type: ignore
|
||||
except AttributeError:
|
||||
model_cls = obj
|
||||
|
||||
if not issubclass(model_cls, BaseModel):
|
||||
raise TypeError('Unsupported type, must be either BaseModel or dataclass')
|
||||
return model_cls
|
||||
|
||||
|
||||
def to_camel(string: str) -> str:
|
||||
return ''.join(word.capitalize() for word in string.split('_'))
|
||||
|
||||
|
||||
def to_lower_camel(string: str) -> str:
|
||||
if len(string) >= 1:
|
||||
pascal_string = to_camel(string)
|
||||
return pascal_string[0].lower() + pascal_string[1:]
|
||||
return string.lower()
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def unique_list(
|
||||
input_list: Union[List[T], Tuple[T, ...]],
|
||||
*,
|
||||
name_factory: Callable[[T], str] = str,
|
||||
) -> List[T]:
|
||||
"""
|
||||
Make a list unique while maintaining order.
|
||||
We update the list if another one with the same name is set
|
||||
(e.g. root validator overridden in subclass)
|
||||
"""
|
||||
result: List[T] = []
|
||||
result_names: List[str] = []
|
||||
for v in input_list:
|
||||
v_name = name_factory(v)
|
||||
if v_name not in result_names:
|
||||
result_names.append(v_name)
|
||||
result.append(v)
|
||||
else:
|
||||
result[result_names.index(v_name)] = v
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class PyObjectStr(str):
|
||||
"""
|
||||
String class where repr doesn't include quotes. Useful with Representation when you want to return a string
|
||||
representation of something that valid (or pseudo-valid) python.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
|
||||
class Representation:
|
||||
"""
|
||||
Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.
|
||||
|
||||
__pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
|
||||
of objects.
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = tuple()
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
"""
|
||||
Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
|
||||
|
||||
Can either return:
|
||||
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
|
||||
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
|
||||
"""
|
||||
attrs = ((s, getattr(self, s)) for s in self.__slots__)
|
||||
return [(a, v) for a, v in attrs if v is not None]
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
"""
|
||||
Name of the instance's class, used in __repr__.
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
|
||||
|
||||
def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
|
||||
"""
|
||||
Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects
|
||||
"""
|
||||
yield self.__repr_name__() + '('
|
||||
yield 1
|
||||
for name, value in self.__repr_args__():
|
||||
if name is not None:
|
||||
yield name + '='
|
||||
yield fmt(value)
|
||||
yield ','
|
||||
yield 0
|
||||
yield -1
|
||||
yield ')'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr_str__(' ')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
|
||||
|
||||
def __rich_repr__(self) -> 'RichReprResult':
|
||||
"""Get fields for Rich library"""
|
||||
for name, field_repr in self.__repr_args__():
|
||||
if name is None:
|
||||
yield field_repr
|
||||
else:
|
||||
yield name, field_repr
|
||||
|
||||
|
||||
class GetterDict(Representation):
|
||||
"""
|
||||
Hack to make object's smell just enough like dicts for validate_model.
|
||||
|
||||
We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
|
||||
"""
|
||||
|
||||
__slots__ = ('_obj',)
|
||||
|
||||
def __init__(self, obj: Any):
|
||||
self._obj = obj
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
return getattr(self._obj, key)
|
||||
except AttributeError as e:
|
||||
raise KeyError(key) from e
|
||||
|
||||
def get(self, key: Any, default: Any = None) -> Any:
|
||||
return getattr(self._obj, key, default)
|
||||
|
||||
def extra_keys(self) -> Set[Any]:
|
||||
"""
|
||||
We don't want to get any other attributes of obj if the model didn't explicitly ask for them
|
||||
"""
|
||||
return set()
|
||||
|
||||
def keys(self) -> List[Any]:
|
||||
"""
|
||||
Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python
|
||||
dictionaries.
|
||||
"""
|
||||
return list(self)
|
||||
|
||||
def values(self) -> List[Any]:
|
||||
return [self[k] for k in self]
|
||||
|
||||
def items(self) -> Iterator[Tuple[str, Any]]:
|
||||
for k in self:
|
||||
yield k, self.get(k)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for name in dir(self._obj):
|
||||
if not name.startswith('_'):
|
||||
yield name
|
||||
|
||||
def __len__(self) -> int:
|
||||
return sum(1 for _ in self)
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
return item in self.keys()
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return dict(self) == dict(other.items())
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [(None, dict(self))]
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
return f'GetterDict[{display_as_type(self._obj)}]'
|
||||
|
||||
|
||||
class ValueItems(Representation):
|
||||
"""
|
||||
Class for more convenient calculation of excluded or included fields on values.
|
||||
"""
|
||||
|
||||
__slots__ = ('_items', '_type')
|
||||
|
||||
def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
|
||||
items = self._coerce_items(items)
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
items = self._normalize_indexes(items, len(value))
|
||||
|
||||
self._items: 'MappingIntStrAny' = items
|
||||
|
||||
def is_excluded(self, item: Any) -> bool:
|
||||
"""
|
||||
Check if item is fully excluded.
|
||||
|
||||
:param item: key or index of a value
|
||||
"""
|
||||
return self.is_true(self._items.get(item))
|
||||
|
||||
def is_included(self, item: Any) -> bool:
|
||||
"""
|
||||
Check if value is contained in self._items
|
||||
|
||||
:param item: key or index of value
|
||||
"""
|
||||
return item in self._items
|
||||
|
||||
def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
|
||||
"""
|
||||
:param e: key or index of element on value
|
||||
:return: raw values for element if self._items is dict and contain needed element
|
||||
"""
|
||||
|
||||
item = self._items.get(e)
|
||||
return item if not self.is_true(item) else None
|
||||
|
||||
def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
|
||||
"""
|
||||
:param items: dict or set of indexes which will be normalized
|
||||
:param v_length: length of sequence indexes of which will be
|
||||
|
||||
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
|
||||
{0: True, 2: True, 3: True}
|
||||
>>> self._normalize_indexes({'__all__': True}, 4)
|
||||
{0: True, 1: True, 2: True, 3: True}
|
||||
"""
|
||||
|
||||
normalized_items: 'DictIntStrAny' = {}
|
||||
all_items = None
|
||||
for i, v in items.items():
|
||||
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
|
||||
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
|
||||
if i == '__all__':
|
||||
all_items = self._coerce_value(v)
|
||||
continue
|
||||
if not isinstance(i, int):
|
||||
raise TypeError(
|
||||
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
|
||||
'expected integer keys or keyword "__all__"'
|
||||
)
|
||||
normalized_i = v_length + i if i < 0 else i
|
||||
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
|
||||
|
||||
if not all_items:
|
||||
return normalized_items
|
||||
if self.is_true(all_items):
|
||||
for i in range(v_length):
|
||||
normalized_items.setdefault(i, ...)
|
||||
return normalized_items
|
||||
for i in range(v_length):
|
||||
normalized_item = normalized_items.setdefault(i, {})
|
||||
if not self.is_true(normalized_item):
|
||||
normalized_items[i] = self.merge(all_items, normalized_item)
|
||||
return normalized_items
|
||||
|
||||
@classmethod
|
||||
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
|
||||
"""
|
||||
Merge a ``base`` item with an ``override`` item.
|
||||
|
||||
Both ``base`` and ``override`` are converted to dictionaries if possible.
|
||||
Sets are converted to dictionaries with the sets entries as keys and
|
||||
Ellipsis as values.
|
||||
|
||||
Each key-value pair existing in ``base`` is merged with ``override``,
|
||||
while the rest of the key-value pairs are updated recursively with this function.
|
||||
|
||||
Merging takes place based on the "union" of keys if ``intersect`` is
|
||||
set to ``False`` (default) and on the intersection of keys if
|
||||
``intersect`` is set to ``True``.
|
||||
"""
|
||||
override = cls._coerce_value(override)
|
||||
base = cls._coerce_value(base)
|
||||
if override is None:
|
||||
return base
|
||||
if cls.is_true(base) or base is None:
|
||||
return override
|
||||
if cls.is_true(override):
|
||||
return base if intersect else override
|
||||
|
||||
# intersection or union of keys while preserving ordering:
|
||||
if intersect:
|
||||
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
|
||||
else:
|
||||
merge_keys = list(base) + [k for k in override if k not in base]
|
||||
|
||||
merged: 'DictIntStrAny' = {}
|
||||
for k in merge_keys:
|
||||
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
|
||||
if merged_item is not None:
|
||||
merged[k] = merged_item
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
|
||||
if isinstance(items, Mapping):
|
||||
pass
|
||||
elif isinstance(items, AbstractSet):
|
||||
items = dict.fromkeys(items, ...)
|
||||
else:
|
||||
class_name = getattr(items, '__class__', '???')
|
||||
assert_never(
|
||||
items,
|
||||
f'Unexpected type of exclude value {class_name}',
|
||||
)
|
||||
return items
|
||||
|
||||
@classmethod
|
||||
def _coerce_value(cls, value: Any) -> Any:
|
||||
if value is None or cls.is_true(value):
|
||||
return value
|
||||
return cls._coerce_items(value)
|
||||
|
||||
@staticmethod
|
||||
def is_true(v: Any) -> bool:
|
||||
return v is True or v is ...
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [(None, self._items)]
|
||||
|
||||
|
||||
class ClassAttribute:
|
||||
"""
|
||||
Hide class attribute from its instances
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'name',
|
||||
'value',
|
||||
)
|
||||
|
||||
def __init__(self, name: str, value: Any) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __get__(self, instance: Any, owner: Type[Any]) -> None:
|
||||
if instance is None:
|
||||
return self.value
|
||||
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
|
||||
|
||||
|
||||
path_types = {
|
||||
'is_dir': 'directory',
|
||||
'is_file': 'file',
|
||||
'is_mount': 'mount point',
|
||||
'is_symlink': 'symlink',
|
||||
'is_block_device': 'block device',
|
||||
'is_char_device': 'char device',
|
||||
'is_fifo': 'FIFO',
|
||||
'is_socket': 'socket',
|
||||
}
|
||||
|
||||
|
||||
def path_type(p: 'Path') -> str:
|
||||
"""
|
||||
Find out what sort of thing a path is.
|
||||
"""
|
||||
assert p.exists(), 'path does not exist'
|
||||
for method, name in path_types.items():
|
||||
if getattr(p, method)():
|
||||
return name
|
||||
|
||||
return 'unknown'
|
||||
|
||||
|
||||
Obj = TypeVar('Obj')
|
||||
|
||||
|
||||
def smart_deepcopy(obj: Obj) -> Obj:
|
||||
"""
|
||||
Return type as is for immutable built-in types
|
||||
Use obj.copy() for built-in empty collections
|
||||
Use copy.deepcopy() for non-empty collections and unknown objects
|
||||
"""
|
||||
|
||||
obj_type = obj.__class__
|
||||
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
|
||||
return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway
|
||||
try:
|
||||
if not obj and obj_type in BUILTIN_COLLECTIONS:
|
||||
# faster way for empty collections, no need to copy its members
|
||||
return obj if obj_type is tuple else obj.copy() # type: ignore # tuple doesn't have copy method
|
||||
except (TypeError, ValueError, RuntimeError):
|
||||
# do we really dare to catch ALL errors? Seems a bit risky
|
||||
pass
|
||||
|
||||
return deepcopy(obj) # slowest way when we actually might need a deepcopy
|
||||
|
||||
|
||||
def is_valid_field(name: str) -> bool:
|
||||
if not name.startswith('_'):
|
||||
return True
|
||||
return ROOT_KEY == name
|
||||
|
||||
|
||||
DUNDER_ATTRIBUTES = {
|
||||
'__annotations__',
|
||||
'__classcell__',
|
||||
'__doc__',
|
||||
'__module__',
|
||||
'__orig_bases__',
|
||||
'__orig_class__',
|
||||
'__qualname__',
|
||||
}
|
||||
|
||||
|
||||
def is_valid_private_name(name: str) -> bool:
|
||||
return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES
|
||||
|
||||
|
||||
_EMPTY = object()
|
||||
|
||||
|
||||
def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
|
||||
"""
|
||||
Check that the items of `left` are the same objects as those in `right`.
|
||||
|
||||
>>> a, b = object(), object()
|
||||
>>> all_identical([a, b, a], [a, b, a])
|
||||
True
|
||||
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
|
||||
False
|
||||
"""
|
||||
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
|
||||
if left_item is not right_item:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def assert_never(obj: NoReturn, msg: str) -> NoReturn:
|
||||
"""
|
||||
Helper to make sure that we have covered all possible types.
|
||||
|
||||
This is mostly useful for ``mypy``, docs:
|
||||
https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
|
||||
"""
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
|
||||
"""Validate that all aliases are the same and if that's the case return the alias"""
|
||||
unique_aliases = set(all_aliases)
|
||||
if len(unique_aliases) > 1:
|
||||
raise ConfigError(
|
||||
f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
|
||||
)
|
||||
return unique_aliases.pop()
|
||||
|
||||
|
||||
def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
|
||||
"""
|
||||
Get alias and all valid values in the `Literal` type of the discriminator field
|
||||
`tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.
|
||||
"""
|
||||
is_root_model = getattr(tp, '__custom_root_type__', False)
|
||||
|
||||
if get_origin(tp) is Annotated:
|
||||
tp = get_args(tp)[0]
|
||||
|
||||
if hasattr(tp, '__pydantic_model__'):
|
||||
tp = tp.__pydantic_model__
|
||||
|
||||
if is_union(get_origin(tp)):
|
||||
alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
|
||||
return alias, tuple(v for values in all_values for v in values)
|
||||
elif is_root_model:
|
||||
union_type = tp.__fields__[ROOT_KEY].type_
|
||||
alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)
|
||||
|
||||
if len(set(all_values)) > 1:
|
||||
raise ConfigError(
|
||||
f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
|
||||
)
|
||||
|
||||
return alias, all_values[0]
|
||||
|
||||
else:
|
||||
try:
|
||||
t_discriminator_type = tp.__fields__[discriminator_key].type_
|
||||
except AttributeError as e:
|
||||
raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
|
||||
except KeyError as e:
|
||||
raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e
|
||||
|
||||
if not is_literal_type(t_discriminator_type):
|
||||
raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')
|
||||
|
||||
return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)
|
||||
|
||||
|
||||
def _get_union_alias_and_all_values(
|
||||
union_type: Type[Any], discriminator_key: str
|
||||
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
|
||||
zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)]
|
||||
# unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
|
||||
all_aliases, all_values = zip(*zipped_aliases_values)
|
||||
return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values
|
||||
765
venv/lib/python3.11/site-packages/pydantic/v1/validators.py
Normal file
765
venv/lib/python3.11/site-packages/pydantic/v1/validators.py
Normal file
@@ -0,0 +1,765 @@
|
||||
import math
|
||||
import re
|
||||
from collections import OrderedDict, deque
|
||||
from collections.abc import Hashable as CollectionsHashable
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from decimal import Decimal, DecimalException
|
||||
from enum import Enum, IntEnum
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
FrozenSet,
|
||||
Generator,
|
||||
Hashable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Pattern,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.v1 import errors
|
||||
from pydantic.v1.datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
|
||||
from pydantic.v1.typing import (
|
||||
AnyCallable,
|
||||
all_literal_values,
|
||||
display_as_type,
|
||||
get_class,
|
||||
is_callable_type,
|
||||
is_literal_type,
|
||||
is_namedtuple,
|
||||
is_none_type,
|
||||
is_typeddict,
|
||||
)
|
||||
from pydantic.v1.utils import almost_equal_floats, lenient_issubclass, sequence_like
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Literal, TypedDict
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt
|
||||
|
||||
ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt]
|
||||
AnyOrderedDict = OrderedDict[Any, Any]
|
||||
Number = Union[int, float, Decimal]
|
||||
StrBytes = Union[str, bytes]
|
||||
|
||||
|
||||
def str_validator(v: Any) -> Union[str]:
|
||||
if isinstance(v, str):
|
||||
if isinstance(v, Enum):
|
||||
return v.value
|
||||
else:
|
||||
return v
|
||||
elif isinstance(v, (float, int, Decimal)):
|
||||
# is there anything else we want to add here? If you think so, create an issue.
|
||||
return str(v)
|
||||
elif isinstance(v, (bytes, bytearray)):
|
||||
return v.decode()
|
||||
else:
|
||||
raise errors.StrError()
|
||||
|
||||
|
||||
def strict_str_validator(v: Any) -> Union[str]:
|
||||
if isinstance(v, str) and not isinstance(v, Enum):
|
||||
return v
|
||||
raise errors.StrError()
|
||||
|
||||
|
||||
def bytes_validator(v: Any) -> Union[bytes]:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
elif isinstance(v, bytearray):
|
||||
return bytes(v)
|
||||
elif isinstance(v, str):
|
||||
return v.encode()
|
||||
elif isinstance(v, (float, int, Decimal)):
|
||||
return str(v).encode()
|
||||
else:
|
||||
raise errors.BytesError()
|
||||
|
||||
|
||||
def strict_bytes_validator(v: Any) -> Union[bytes]:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
elif isinstance(v, bytearray):
|
||||
return bytes(v)
|
||||
else:
|
||||
raise errors.BytesError()
|
||||
|
||||
|
||||
BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'}
|
||||
BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'}
|
||||
|
||||
|
||||
def bool_validator(v: Any) -> bool:
|
||||
if v is True or v is False:
|
||||
return v
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode()
|
||||
if isinstance(v, str):
|
||||
v = v.lower()
|
||||
try:
|
||||
if v in BOOL_TRUE:
|
||||
return True
|
||||
if v in BOOL_FALSE:
|
||||
return False
|
||||
except TypeError:
|
||||
raise errors.BoolError()
|
||||
raise errors.BoolError()
|
||||
|
||||
|
||||
# matches the default limit cpython, see https://github.com/python/cpython/pull/96500
|
||||
max_str_int = 4_300
|
||||
|
||||
|
||||
def int_validator(v: Any) -> int:
|
||||
if isinstance(v, int) and not (v is True or v is False):
|
||||
return v
|
||||
|
||||
# see https://github.com/pydantic/pydantic/issues/1477 and in turn, https://github.com/python/cpython/issues/95778
|
||||
# this check should be unnecessary once patch releases are out for 3.7, 3.8, 3.9 and 3.10
|
||||
# but better to check here until then.
|
||||
# NOTICE: this does not fully protect user from the DOS risk since the standard library JSON implementation
|
||||
# (and other std lib modules like xml) use `int()` and are likely called before this, the best workaround is to
|
||||
# 1. update to the latest patch release of python once released, 2. use a different JSON library like ujson
|
||||
if isinstance(v, (str, bytes, bytearray)) and len(v) > max_str_int:
|
||||
raise errors.IntegerError()
|
||||
|
||||
try:
|
||||
return int(v)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
raise errors.IntegerError()
|
||||
|
||||
|
||||
def strict_int_validator(v: Any) -> int:
|
||||
if isinstance(v, int) and not (v is True or v is False):
|
||||
return v
|
||||
raise errors.IntegerError()
|
||||
|
||||
|
||||
def float_validator(v: Any) -> float:
|
||||
if isinstance(v, float):
|
||||
return v
|
||||
|
||||
try:
|
||||
return float(v)
|
||||
except (TypeError, ValueError):
|
||||
raise errors.FloatError()
|
||||
|
||||
|
||||
def strict_float_validator(v: Any) -> float:
|
||||
if isinstance(v, float):
|
||||
return v
|
||||
raise errors.FloatError()
|
||||
|
||||
|
||||
def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number':
|
||||
allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None)
|
||||
if allow_inf_nan is None:
|
||||
allow_inf_nan = config.allow_inf_nan
|
||||
|
||||
if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)):
|
||||
raise errors.NumberNotFiniteError()
|
||||
return v
|
||||
|
||||
|
||||
def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number':
|
||||
field_type: ConstrainedNumber = field.type_
|
||||
if field_type.multiple_of is not None:
|
||||
mod = float(v) / float(field_type.multiple_of) % 1
|
||||
if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0):
|
||||
raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of)
|
||||
return v
|
||||
|
||||
|
||||
def number_size_validator(v: 'Number', field: 'ModelField') -> 'Number':
|
||||
field_type: ConstrainedNumber = field.type_
|
||||
if field_type.gt is not None and not v > field_type.gt:
|
||||
raise errors.NumberNotGtError(limit_value=field_type.gt)
|
||||
elif field_type.ge is not None and not v >= field_type.ge:
|
||||
raise errors.NumberNotGeError(limit_value=field_type.ge)
|
||||
|
||||
if field_type.lt is not None and not v < field_type.lt:
|
||||
raise errors.NumberNotLtError(limit_value=field_type.lt)
|
||||
if field_type.le is not None and not v <= field_type.le:
|
||||
raise errors.NumberNotLeError(limit_value=field_type.le)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constant_validator(v: 'Any', field: 'ModelField') -> 'Any':
|
||||
"""Validate ``const`` fields.
|
||||
|
||||
The value provided for a ``const`` field must be equal to the default value
|
||||
of the field. This is to support the keyword of the same name in JSON
|
||||
Schema.
|
||||
"""
|
||||
if v != field.default:
|
||||
raise errors.WrongConstantError(given=v, permitted=[field.default])
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def anystr_length_validator(v: 'StrBytes', config: 'BaseConfig') -> 'StrBytes':
|
||||
v_len = len(v)
|
||||
|
||||
min_length = config.min_anystr_length
|
||||
if v_len < min_length:
|
||||
raise errors.AnyStrMinLengthError(limit_value=min_length)
|
||||
|
||||
max_length = config.max_anystr_length
|
||||
if max_length is not None and v_len > max_length:
|
||||
raise errors.AnyStrMaxLengthError(limit_value=max_length)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def anystr_strip_whitespace(v: 'StrBytes') -> 'StrBytes':
|
||||
return v.strip()
|
||||
|
||||
|
||||
def anystr_upper(v: 'StrBytes') -> 'StrBytes':
|
||||
return v.upper()
|
||||
|
||||
|
||||
def anystr_lower(v: 'StrBytes') -> 'StrBytes':
|
||||
return v.lower()
|
||||
|
||||
|
||||
def ordered_dict_validator(v: Any) -> 'AnyOrderedDict':
|
||||
if isinstance(v, OrderedDict):
|
||||
return v
|
||||
|
||||
try:
|
||||
return OrderedDict(v)
|
||||
except (TypeError, ValueError):
|
||||
raise errors.DictError()
|
||||
|
||||
|
||||
def dict_validator(v: Any) -> Dict[Any, Any]:
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
|
||||
try:
|
||||
return dict(v)
|
||||
except (TypeError, ValueError):
|
||||
raise errors.DictError()
|
||||
|
||||
|
||||
def list_validator(v: Any) -> List[Any]:
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return list(v)
|
||||
else:
|
||||
raise errors.ListError()
|
||||
|
||||
|
||||
def tuple_validator(v: Any) -> Tuple[Any, ...]:
|
||||
if isinstance(v, tuple):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return tuple(v)
|
||||
else:
|
||||
raise errors.TupleError()
|
||||
|
||||
|
||||
def set_validator(v: Any) -> Set[Any]:
|
||||
if isinstance(v, set):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return set(v)
|
||||
else:
|
||||
raise errors.SetError()
|
||||
|
||||
|
||||
def frozenset_validator(v: Any) -> FrozenSet[Any]:
|
||||
if isinstance(v, frozenset):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return frozenset(v)
|
||||
else:
|
||||
raise errors.FrozenSetError()
|
||||
|
||||
|
||||
def deque_validator(v: Any) -> Deque[Any]:
|
||||
if isinstance(v, deque):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return deque(v)
|
||||
else:
|
||||
raise errors.DequeError()
|
||||
|
||||
|
||||
def enum_member_validator(v: Any, field: 'ModelField', config: 'BaseConfig') -> Enum:
|
||||
try:
|
||||
enum_v = field.type_(v)
|
||||
except ValueError:
|
||||
# field.type_ should be an enum, so will be iterable
|
||||
raise errors.EnumMemberError(enum_values=list(field.type_))
|
||||
return enum_v.value if config.use_enum_values else enum_v
|
||||
|
||||
|
||||
def uuid_validator(v: Any, field: 'ModelField') -> UUID:
|
||||
try:
|
||||
if isinstance(v, str):
|
||||
v = UUID(v)
|
||||
elif isinstance(v, (bytes, bytearray)):
|
||||
try:
|
||||
v = UUID(v.decode())
|
||||
except ValueError:
|
||||
# 16 bytes in big-endian order as the bytes argument fail
|
||||
# the above check
|
||||
v = UUID(bytes=v)
|
||||
except ValueError:
|
||||
raise errors.UUIDError()
|
||||
|
||||
if not isinstance(v, UUID):
|
||||
raise errors.UUIDError()
|
||||
|
||||
required_version = getattr(field.type_, '_required_version', None)
|
||||
if required_version and v.version != required_version:
|
||||
raise errors.UUIDVersionError(required_version=required_version)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def decimal_validator(v: Any) -> Decimal:
|
||||
if isinstance(v, Decimal):
|
||||
return v
|
||||
elif isinstance(v, (bytes, bytearray)):
|
||||
v = v.decode()
|
||||
|
||||
v = str(v).strip()
|
||||
|
||||
try:
|
||||
v = Decimal(v)
|
||||
except DecimalException:
|
||||
raise errors.DecimalError()
|
||||
|
||||
if not v.is_finite():
|
||||
raise errors.DecimalIsNotFiniteError()
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def hashable_validator(v: Any) -> Hashable:
|
||||
if isinstance(v, Hashable):
|
||||
return v
|
||||
|
||||
raise errors.HashableError()
|
||||
|
||||
|
||||
def ip_v4_address_validator(v: Any) -> IPv4Address:
|
||||
if isinstance(v, IPv4Address):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv4Address(v)
|
||||
except ValueError:
|
||||
raise errors.IPv4AddressError()
|
||||
|
||||
|
||||
def ip_v6_address_validator(v: Any) -> IPv6Address:
|
||||
if isinstance(v, IPv6Address):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv6Address(v)
|
||||
except ValueError:
|
||||
raise errors.IPv6AddressError()
|
||||
|
||||
|
||||
def ip_v4_network_validator(v: Any) -> IPv4Network:
|
||||
"""
|
||||
Assume IPv4Network initialised with a default ``strict`` argument
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
|
||||
"""
|
||||
if isinstance(v, IPv4Network):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv4Network(v)
|
||||
except ValueError:
|
||||
raise errors.IPv4NetworkError()
|
||||
|
||||
|
||||
def ip_v6_network_validator(v: Any) -> IPv6Network:
|
||||
"""
|
||||
Assume IPv6Network initialised with a default ``strict`` argument
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
|
||||
"""
|
||||
if isinstance(v, IPv6Network):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv6Network(v)
|
||||
except ValueError:
|
||||
raise errors.IPv6NetworkError()
|
||||
|
||||
|
||||
def ip_v4_interface_validator(v: Any) -> IPv4Interface:
|
||||
if isinstance(v, IPv4Interface):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv4Interface(v)
|
||||
except ValueError:
|
||||
raise errors.IPv4InterfaceError()
|
||||
|
||||
|
||||
def ip_v6_interface_validator(v: Any) -> IPv6Interface:
|
||||
if isinstance(v, IPv6Interface):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv6Interface(v)
|
||||
except ValueError:
|
||||
raise errors.IPv6InterfaceError()
|
||||
|
||||
|
||||
def path_validator(v: Any) -> Path:
|
||||
if isinstance(v, Path):
|
||||
return v
|
||||
|
||||
try:
|
||||
return Path(v)
|
||||
except TypeError:
|
||||
raise errors.PathError()
|
||||
|
||||
|
||||
def path_exists_validator(v: Any) -> Path:
|
||||
if not v.exists():
|
||||
raise errors.PathNotExistsError(path=v)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def callable_validator(v: Any) -> AnyCallable:
|
||||
"""
|
||||
Perform a simple check if the value is callable.
|
||||
|
||||
Note: complete matching of argument type hints and return types is not performed
|
||||
"""
|
||||
if callable(v):
|
||||
return v
|
||||
|
||||
raise errors.CallableError(value=v)
|
||||
|
||||
|
||||
def enum_validator(v: Any) -> Enum:
|
||||
if isinstance(v, Enum):
|
||||
return v
|
||||
|
||||
raise errors.EnumError(value=v)
|
||||
|
||||
|
||||
def int_enum_validator(v: Any) -> IntEnum:
|
||||
if isinstance(v, IntEnum):
|
||||
return v
|
||||
|
||||
raise errors.IntEnumError(value=v)
|
||||
|
||||
|
||||
def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
|
||||
permitted_choices = all_literal_values(type_)
|
||||
|
||||
# To have a O(1) complexity and still return one of the values set inside the `Literal`,
|
||||
# we create a dict with the set values (a set causes some problems with the way intersection works).
|
||||
# In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`)
|
||||
allowed_choices = {v: v for v in permitted_choices}
|
||||
|
||||
def literal_validator(v: Any) -> Any:
|
||||
try:
|
||||
return allowed_choices[v]
|
||||
except (KeyError, TypeError):
|
||||
raise errors.WrongConstantError(given=v, permitted=permitted_choices)
|
||||
|
||||
return literal_validator
|
||||
|
||||
|
||||
def constr_length_validator(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
v_len = len(v)
|
||||
|
||||
min_length = field.type_.min_length if field.type_.min_length is not None else config.min_anystr_length
|
||||
if v_len < min_length:
|
||||
raise errors.AnyStrMinLengthError(limit_value=min_length)
|
||||
|
||||
max_length = field.type_.max_length if field.type_.max_length is not None else config.max_anystr_length
|
||||
if max_length is not None and v_len > max_length:
|
||||
raise errors.AnyStrMaxLengthError(limit_value=max_length)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constr_strip_whitespace(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace
|
||||
if strip_whitespace:
|
||||
v = v.strip()
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constr_upper(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
upper = field.type_.to_upper or config.anystr_upper
|
||||
if upper:
|
||||
v = v.upper()
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constr_lower(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
lower = field.type_.to_lower or config.anystr_lower
|
||||
if lower:
|
||||
v = v.lower()
|
||||
return v
|
||||
|
||||
|
||||
def validate_json(v: Any, config: 'BaseConfig') -> Any:
|
||||
if v is None:
|
||||
# pass None through to other validators
|
||||
return v
|
||||
try:
|
||||
return config.json_loads(v) # type: ignore
|
||||
except ValueError:
|
||||
raise errors.JsonError()
|
||||
except TypeError:
|
||||
raise errors.JsonTypeError()
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]:
|
||||
def arbitrary_type_validator(v: Any) -> T:
|
||||
if isinstance(v, type_):
|
||||
return v
|
||||
raise errors.ArbitraryTypeError(expected_arbitrary_type=type_)
|
||||
|
||||
return arbitrary_type_validator
|
||||
|
||||
|
||||
def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]:
|
||||
def class_validator(v: Any) -> Type[T]:
|
||||
if lenient_issubclass(v, type_):
|
||||
return v
|
||||
raise errors.SubclassError(expected_class=type_)
|
||||
|
||||
return class_validator
|
||||
|
||||
|
||||
def any_class_validator(v: Any) -> Type[T]:
|
||||
if isinstance(v, type):
|
||||
return v
|
||||
raise errors.ClassError()
|
||||
|
||||
|
||||
def none_validator(v: Any) -> 'Literal[None]':
|
||||
if v is None:
|
||||
return v
|
||||
raise errors.NotNoneError()
|
||||
|
||||
|
||||
def pattern_validator(v: Any) -> Pattern[str]:
|
||||
if isinstance(v, Pattern):
|
||||
return v
|
||||
|
||||
str_value = str_validator(v)
|
||||
|
||||
try:
|
||||
return re.compile(str_value)
|
||||
except re.error:
|
||||
raise errors.PatternError()
|
||||
|
||||
|
||||
NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple)
|
||||
|
||||
|
||||
def make_namedtuple_validator(
|
||||
namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig']
|
||||
) -> Callable[[Tuple[Any, ...]], NamedTupleT]:
|
||||
from pydantic.v1.annotated_types import create_model_from_namedtuple
|
||||
|
||||
NamedTupleModel = create_model_from_namedtuple(
|
||||
namedtuple_cls,
|
||||
__config__=config,
|
||||
__module__=namedtuple_cls.__module__,
|
||||
)
|
||||
namedtuple_cls.__pydantic_model__ = NamedTupleModel # type: ignore[attr-defined]
|
||||
|
||||
def namedtuple_validator(values: Tuple[Any, ...]) -> NamedTupleT:
|
||||
annotations = NamedTupleModel.__annotations__
|
||||
|
||||
if len(values) > len(annotations):
|
||||
raise errors.ListMaxLengthError(limit_value=len(annotations))
|
||||
|
||||
dict_values: Dict[str, Any] = dict(zip(annotations, values))
|
||||
validated_dict_values: Dict[str, Any] = dict(NamedTupleModel(**dict_values))
|
||||
return namedtuple_cls(**validated_dict_values)
|
||||
|
||||
return namedtuple_validator
|
||||
|
||||
|
||||
def make_typeddict_validator(
|
||||
typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type]
|
||||
) -> Callable[[Any], Dict[str, Any]]:
|
||||
from pydantic.v1.annotated_types import create_model_from_typeddict
|
||||
|
||||
TypedDictModel = create_model_from_typeddict(
|
||||
typeddict_cls,
|
||||
__config__=config,
|
||||
__module__=typeddict_cls.__module__,
|
||||
)
|
||||
typeddict_cls.__pydantic_model__ = TypedDictModel # type: ignore[attr-defined]
|
||||
|
||||
def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]: # type: ignore[valid-type]
|
||||
return TypedDictModel.parse_obj(values).dict(exclude_unset=True)
|
||||
|
||||
return typeddict_validator
|
||||
|
||||
|
||||
class IfConfig:
|
||||
def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None:
|
||||
self.validator = validator
|
||||
self.config_attr_names = config_attr_names
|
||||
self.ignored_value = ignored_value
|
||||
|
||||
def check(self, config: Type['BaseConfig']) -> bool:
|
||||
return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names)
|
||||
|
||||
|
||||
# order is important here, for example: bool is a subclass of int so has to come first, datetime before date same,
|
||||
# IPv4Interface before IPv4Address, etc
|
||||
_VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [
|
||||
(IntEnum, [int_validator, enum_member_validator]),
|
||||
(Enum, [enum_member_validator]),
|
||||
(
|
||||
str,
|
||||
[
|
||||
str_validator,
|
||||
IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
|
||||
IfConfig(anystr_upper, 'anystr_upper'),
|
||||
IfConfig(anystr_lower, 'anystr_lower'),
|
||||
IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
|
||||
],
|
||||
),
|
||||
(
|
||||
bytes,
|
||||
[
|
||||
bytes_validator,
|
||||
IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
|
||||
IfConfig(anystr_upper, 'anystr_upper'),
|
||||
IfConfig(anystr_lower, 'anystr_lower'),
|
||||
IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
|
||||
],
|
||||
),
|
||||
(bool, [bool_validator]),
|
||||
(int, [int_validator]),
|
||||
(float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]),
|
||||
(Path, [path_validator]),
|
||||
(datetime, [parse_datetime]),
|
||||
(date, [parse_date]),
|
||||
(time, [parse_time]),
|
||||
(timedelta, [parse_duration]),
|
||||
(OrderedDict, [ordered_dict_validator]),
|
||||
(dict, [dict_validator]),
|
||||
(list, [list_validator]),
|
||||
(tuple, [tuple_validator]),
|
||||
(set, [set_validator]),
|
||||
(frozenset, [frozenset_validator]),
|
||||
(deque, [deque_validator]),
|
||||
(UUID, [uuid_validator]),
|
||||
(Decimal, [decimal_validator]),
|
||||
(IPv4Interface, [ip_v4_interface_validator]),
|
||||
(IPv6Interface, [ip_v6_interface_validator]),
|
||||
(IPv4Address, [ip_v4_address_validator]),
|
||||
(IPv6Address, [ip_v6_address_validator]),
|
||||
(IPv4Network, [ip_v4_network_validator]),
|
||||
(IPv6Network, [ip_v6_network_validator]),
|
||||
]
|
||||
|
||||
|
||||
def find_validators( # noqa: C901 (ignore complexity)
|
||||
type_: Type[Any], config: Type['BaseConfig']
|
||||
) -> Generator[AnyCallable, None, None]:
|
||||
from pydantic.v1.dataclasses import is_builtin_dataclass, make_dataclass_validator
|
||||
|
||||
if type_ is Any or type_ is object:
|
||||
return
|
||||
type_type = type_.__class__
|
||||
if type_type == ForwardRef or type_type == TypeVar:
|
||||
return
|
||||
|
||||
if is_none_type(type_):
|
||||
yield none_validator
|
||||
return
|
||||
if type_ is Pattern or type_ is re.Pattern:
|
||||
yield pattern_validator
|
||||
return
|
||||
if type_ is Hashable or type_ is CollectionsHashable:
|
||||
yield hashable_validator
|
||||
return
|
||||
if is_callable_type(type_):
|
||||
yield callable_validator
|
||||
return
|
||||
if is_literal_type(type_):
|
||||
yield make_literal_validator(type_)
|
||||
return
|
||||
if is_builtin_dataclass(type_):
|
||||
yield from make_dataclass_validator(type_, config)
|
||||
return
|
||||
if type_ is Enum:
|
||||
yield enum_validator
|
||||
return
|
||||
if type_ is IntEnum:
|
||||
yield int_enum_validator
|
||||
return
|
||||
if is_namedtuple(type_):
|
||||
yield tuple_validator
|
||||
yield make_namedtuple_validator(type_, config)
|
||||
return
|
||||
if is_typeddict(type_):
|
||||
yield make_typeddict_validator(type_, config)
|
||||
return
|
||||
|
||||
class_ = get_class(type_)
|
||||
if class_ is not None:
|
||||
if class_ is not Any and isinstance(class_, type):
|
||||
yield make_class_validator(class_)
|
||||
else:
|
||||
yield any_class_validator
|
||||
return
|
||||
|
||||
for val_type, validators in _VALIDATORS:
|
||||
try:
|
||||
if issubclass(type_, val_type):
|
||||
for v in validators:
|
||||
if isinstance(v, IfConfig):
|
||||
if v.check(config):
|
||||
yield v.validator
|
||||
else:
|
||||
yield v
|
||||
return
|
||||
except TypeError:
|
||||
raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})')
|
||||
|
||||
if config.arbitrary_types_allowed:
|
||||
yield make_arbitrary_type_validator(type_)
|
||||
else:
|
||||
raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config')
|
||||
38
venv/lib/python3.11/site-packages/pydantic/v1/version.py
Normal file
38
venv/lib/python3.11/site-packages/pydantic/v1/version.py
Normal file
@@ -0,0 +1,38 @@
|
||||
__all__ = 'compiled', 'VERSION', 'version_info'
|
||||
|
||||
VERSION = '1.10.18'
|
||||
|
||||
try:
|
||||
import cython # type: ignore
|
||||
except ImportError:
|
||||
compiled: bool = False
|
||||
else: # pragma: no cover
|
||||
try:
|
||||
compiled = cython.compiled
|
||||
except AttributeError:
|
||||
compiled = False
|
||||
|
||||
|
||||
def version_info() -> str:
|
||||
import platform
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
optional_deps = []
|
||||
for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'):
|
||||
try:
|
||||
import_module(p.replace('-', '_'))
|
||||
except ImportError:
|
||||
continue
|
||||
optional_deps.append(p)
|
||||
|
||||
info = {
|
||||
'pydantic version': VERSION,
|
||||
'pydantic compiled': compiled,
|
||||
'install path': Path(__file__).resolve().parent,
|
||||
'python version': sys.version,
|
||||
'platform': platform.platform(),
|
||||
'optional deps. installed': optional_deps,
|
||||
}
|
||||
return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items())
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user