diff --git a/flag_engine/context/types.py b/flag_engine/context/types.py index 06e48480..193e573a 100644 --- a/flag_engine/context/types.py +++ b/flag_engine/context/types.py @@ -4,11 +4,16 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, Optional, TypedDict, Union +from typing import Any, Dict, Generic, List, Literal, Optional, Union -from typing_extensions import NotRequired +from typing_extensions import NotRequired, TypedDict -from flag_engine.segments.types import ConditionOperator, ContextValue, RuleType +from flag_engine.segments.types import ( + ConditionOperator, + ContextValue, + RuleType, + SegmentMetadataT, +) class EnvironmentContext(TypedDict): @@ -58,16 +63,16 @@ class FeatureContext(TypedDict): priority: NotRequired[float] -class SegmentContext(TypedDict): +class SegmentContext(TypedDict, Generic[SegmentMetadataT]): key: str name: str rules: List[SegmentRule] overrides: NotRequired[List[FeatureContext]] - metadata: NotRequired[Dict[str, Any]] + metadata: NotRequired[SegmentMetadataT] -class EvaluationContext(TypedDict): +class EvaluationContext(TypedDict, Generic[SegmentMetadataT]): environment: EnvironmentContext identity: NotRequired[Optional[IdentityContext]] - segments: NotRequired[Dict[str, SegmentContext]] + segments: NotRequired[Dict[str, SegmentContext[SegmentMetadataT]]] features: NotRequired[Dict[str, FeatureContext]] diff --git a/flag_engine/result/types.py b/flag_engine/result/types.py index 1404d1c4..a6741e1c 100644 --- a/flag_engine/result/types.py +++ b/flag_engine/result/types.py @@ -4,9 +4,11 @@ from __future__ import annotations -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, Generic, List -from typing_extensions import NotRequired +from typing_extensions import NotRequired, TypedDict + +from flag_engine.segments.types import SegmentMetadataT class FlagResult(TypedDict): @@ -17,12 +19,12 @@ class FlagResult(TypedDict): reason: str -class SegmentResult(TypedDict): +class SegmentResult(TypedDict, Generic[SegmentMetadataT]): key: str name: str - metadata: NotRequired[Dict[str, Any]] + metadata: NotRequired[SegmentMetadataT] -class EvaluationResult(TypedDict): +class EvaluationResult(TypedDict, Generic[SegmentMetadataT]): flags: Dict[str, FlagResult] - segments: List[SegmentResult] + segments: List[SegmentResult[SegmentMetadataT]] diff --git a/flag_engine/segments/evaluator.py b/flag_engine/segments/evaluator.py index f82f2b30..6cfc7999 100644 --- a/flag_engine/segments/evaluator.py +++ b/flag_engine/segments/evaluator.py @@ -20,7 +20,12 @@ ) from flag_engine.result.types import EvaluationResult, FlagResult, SegmentResult from flag_engine.segments import constants -from flag_engine.segments.types import ConditionOperator, ContextValue, is_context_value +from flag_engine.segments.types import ( + ConditionOperator, + ContextValue, + SegmentMetadataT, + is_context_value, +) from flag_engine.segments.utils import escape_double_quotes, get_matching_function from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids from flag_engine.utils.semver import is_semver @@ -32,14 +37,16 @@ class FeatureContextWithSegmentName(typing.TypedDict): segment_name: str -def get_evaluation_result(context: EvaluationContext) -> EvaluationResult: +def get_evaluation_result( + context: EvaluationContext[SegmentMetadataT], +) -> EvaluationResult[SegmentMetadataT]: """ Get the evaluation result for a given context. :param context: the evaluation context :return: EvaluationResult containing the context, flags, and segments """ - segments: list[SegmentResult] = [] + segments: list[SegmentResult[SegmentMetadataT]] = [] flags: dict[str, FlagResult] = {} segment_feature_contexts: dict[SupportsStr, FeatureContextWithSegmentName] = {} @@ -48,7 +55,7 @@ def get_evaluation_result(context: EvaluationContext) -> EvaluationResult: if not is_context_in_segment(context, segment_context): continue - segment_result: SegmentResult = { + segment_result: SegmentResult[SegmentMetadataT] = { "key": segment_context["key"], "name": segment_context["name"], } @@ -152,8 +159,8 @@ def get_flag_result_from_feature_context( def is_context_in_segment( - context: EvaluationContext, - segment_context: SegmentContext, + context: EvaluationContext[SegmentMetadataT], + segment_context: SegmentContext[SegmentMetadataT], ) -> bool: return bool(rules := segment_context["rules"]) and all( context_matches_rule( @@ -164,7 +171,7 @@ def is_context_in_segment( def context_matches_rule( - context: EvaluationContext, + context: EvaluationContext[SegmentMetadataT], rule: SegmentRule, segment_key: SupportsStr, ) -> bool: @@ -194,7 +201,7 @@ def context_matches_rule( def context_matches_condition( - context: EvaluationContext, + context: EvaluationContext[SegmentMetadataT], condition: SegmentCondition, segment_key: SupportsStr, ) -> bool: @@ -255,7 +262,7 @@ def context_matches_condition( def get_context_value( - context: EvaluationContext, + context: EvaluationContext[SegmentMetadataT], property: str, ) -> ContextValue: value = None @@ -353,7 +360,7 @@ def inner( @lru_cache def _get_context_value_getter( property: str, -) -> typing.Callable[[EvaluationContext], ContextValue]: +) -> typing.Callable[[EvaluationContext[SegmentMetadataT]], ContextValue]: """ Get a function to retrieve a context value based on property value, assumed to be either a JSONPath string or a trait key. @@ -370,7 +377,7 @@ def _get_context_value_getter( f'$.identity.traits["{escape_double_quotes(property)}"]', ) - def getter(context: EvaluationContext) -> ContextValue: + def getter(context: EvaluationContext[SegmentMetadataT]) -> ContextValue: if typing.TYPE_CHECKING: # pragma: no cover # Ugly hack to satisfy mypy :( data = dict(context) diff --git a/flag_engine/segments/types.py b/flag_engine/segments/types.py index 01d46f47..118b0d37 100644 --- a/flag_engine/segments/types.py +++ b/flag_engine/segments/types.py @@ -1,6 +1,10 @@ -from typing import Any, Literal, Union, get_args +from __future__ import annotations -from typing_extensions import TypeGuard +from typing import Any, Dict, Literal, Union, get_args + +from typing_extensions import TypeGuard, TypeVar + +SegmentMetadataT = TypeVar("SegmentMetadataT", default=Dict[str, object]) ConditionOperator = Literal[ "EQUAL", diff --git a/flag_engine/types/__init__.py b/flag_engine/types/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unit/test_engine.py b/tests/unit/test_engine.py index fe0fa96f..980c2002 100644 --- a/tests/unit/test_engine.py +++ b/tests/unit/test_engine.py @@ -1,4 +1,11 @@ import json +from typing import TYPE_CHECKING, TypedDict + +if not TYPE_CHECKING: + # `reveal_type` is a pseudo-builtin only available when type checking. + # Define a no-op version here so that we can call it in the tests. + def reveal_type(x: object) -> None: ... + from flag_engine.context.types import EvaluationContext, IdentityContext, SegmentContext from flag_engine.engine import get_evaluation_result @@ -357,3 +364,80 @@ def test_get_evaluation_result__segment_override__no_priority__returns_expected( {"key": "3", "name": "another_segment"}, ], } + + +def test_segment_metadata_generic_type__returns_expected() -> None: + # Given + class CustomMetadata(TypedDict): + foo: str + bar: int + + segment_metadata = CustomMetadata(foo="hello", bar=123) + + evaluation_context: EvaluationContext[CustomMetadata] = { + "environment": {"key": "api-key", "name": ""}, + "segments": { + "1": { + "key": "1", + "name": "my_segment", + "rules": [ + { + "type": "ALL", + "conditions": [ + { + "property": "$.environment.name", + "operator": "EQUAL", + "value": "", + } + ], + "rules": [], + } + ], + "metadata": segment_metadata, + }, + }, + } + + # When + result = get_evaluation_result(evaluation_context) + + # Then + assert result["segments"][0]["metadata"] is segment_metadata + reveal_type(result["segments"][0]["metadata"]) # CustomMetadata + + +def test_segment_metadata_generic_type__default__returns_expected() -> None: + # Given + segment_metadata = {"hello": object()} + + # we don't specify generic type, but mypy is happy with this + evaluation_context: EvaluationContext = { + "environment": {"key": "api-key", "name": ""}, + "segments": { + "1": { + "key": "1", + "name": "my_segment", + "rules": [ + { + "type": "ALL", + "conditions": [ + { + "property": "$.environment.name", + "operator": "EQUAL", + "value": "", + } + ], + "rules": [], + } + ], + "metadata": segment_metadata, + }, + }, + } + + # When + result = get_evaluation_result(evaluation_context) + + # Then + assert result["segments"][0]["metadata"] is segment_metadata + reveal_type(result["segments"][0]["metadata"]) # Dict[str, object]