Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions flag_engine/context/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@

from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
from typing import Any, Dict, Generic, List, Literal, Optional, Union

from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict

from flag_engine.segments.types import ConditionOperator, ContextValue, RuleType
from flag_engine.segments.types import (
ConditionOperator,
ContextValue,
RuleType,
SegmentMetadataT,
)


class EnvironmentContext(TypedDict):
Expand Down Expand Up @@ -58,16 +63,16 @@ class FeatureContext(TypedDict):
priority: NotRequired[float]


class SegmentContext(TypedDict):
class SegmentContext(TypedDict, Generic[SegmentMetadataT]):
key: str
name: str
rules: List[SegmentRule]
overrides: NotRequired[List[FeatureContext]]
metadata: NotRequired[Dict[str, Any]]
metadata: NotRequired[SegmentMetadataT]


class EvaluationContext(TypedDict):
class EvaluationContext(TypedDict, Generic[SegmentMetadataT]):
environment: EnvironmentContext
identity: NotRequired[Optional[IdentityContext]]
segments: NotRequired[Dict[str, SegmentContext]]
segments: NotRequired[Dict[str, SegmentContext[SegmentMetadataT]]]
features: NotRequired[Dict[str, FeatureContext]]
14 changes: 8 additions & 6 deletions flag_engine/result/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from __future__ import annotations

from typing import Any, Dict, List, TypedDict
from typing import Any, Dict, Generic, List

from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypedDict

from flag_engine.segments.types import SegmentMetadataT


class FlagResult(TypedDict):
Expand All @@ -17,12 +19,12 @@ class FlagResult(TypedDict):
reason: str


class SegmentResult(TypedDict):
class SegmentResult(TypedDict, Generic[SegmentMetadataT]):
key: str
name: str
metadata: NotRequired[Dict[str, Any]]
metadata: NotRequired[SegmentMetadataT]


class EvaluationResult(TypedDict):
class EvaluationResult(TypedDict, Generic[SegmentMetadataT]):
flags: Dict[str, FlagResult]
segments: List[SegmentResult]
segments: List[SegmentResult[SegmentMetadataT]]
29 changes: 18 additions & 11 deletions flag_engine/segments/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
)
from flag_engine.result.types import EvaluationResult, FlagResult, SegmentResult
from flag_engine.segments import constants
from flag_engine.segments.types import ConditionOperator, ContextValue, is_context_value
from flag_engine.segments.types import (
ConditionOperator,
ContextValue,
SegmentMetadataT,
is_context_value,
)
from flag_engine.segments.utils import escape_double_quotes, get_matching_function
from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids
from flag_engine.utils.semver import is_semver
Expand All @@ -32,14 +37,16 @@ class FeatureContextWithSegmentName(typing.TypedDict):
segment_name: str


def get_evaluation_result(context: EvaluationContext) -> EvaluationResult:
def get_evaluation_result(
context: EvaluationContext[SegmentMetadataT],
) -> EvaluationResult[SegmentMetadataT]:
"""
Get the evaluation result for a given context.

:param context: the evaluation context
:return: EvaluationResult containing the context, flags, and segments
"""
segments: list[SegmentResult] = []
segments: list[SegmentResult[SegmentMetadataT]] = []
flags: dict[str, FlagResult] = {}

segment_feature_contexts: dict[SupportsStr, FeatureContextWithSegmentName] = {}
Expand All @@ -48,7 +55,7 @@ def get_evaluation_result(context: EvaluationContext) -> EvaluationResult:
if not is_context_in_segment(context, segment_context):
continue

segment_result: SegmentResult = {
segment_result: SegmentResult[SegmentMetadataT] = {
"key": segment_context["key"],
"name": segment_context["name"],
}
Expand Down Expand Up @@ -152,8 +159,8 @@ def get_flag_result_from_feature_context(


def is_context_in_segment(
context: EvaluationContext,
segment_context: SegmentContext,
context: EvaluationContext[SegmentMetadataT],
segment_context: SegmentContext[SegmentMetadataT],
) -> bool:
return bool(rules := segment_context["rules"]) and all(
context_matches_rule(
Expand All @@ -164,7 +171,7 @@ def is_context_in_segment(


def context_matches_rule(
context: EvaluationContext,
context: EvaluationContext[SegmentMetadataT],
rule: SegmentRule,
segment_key: SupportsStr,
) -> bool:
Expand Down Expand Up @@ -194,7 +201,7 @@ def context_matches_rule(


def context_matches_condition(
context: EvaluationContext,
context: EvaluationContext[SegmentMetadataT],
condition: SegmentCondition,
segment_key: SupportsStr,
) -> bool:
Expand Down Expand Up @@ -255,7 +262,7 @@ def context_matches_condition(


def get_context_value(
context: EvaluationContext,
context: EvaluationContext[SegmentMetadataT],
property: str,
) -> ContextValue:
value = None
Expand Down Expand Up @@ -353,7 +360,7 @@ def inner(
@lru_cache
def _get_context_value_getter(
property: str,
) -> typing.Callable[[EvaluationContext], ContextValue]:
) -> typing.Callable[[EvaluationContext[SegmentMetadataT]], ContextValue]:
"""
Get a function to retrieve a context value based on property value,
assumed to be either a JSONPath string or a trait key.
Expand All @@ -370,7 +377,7 @@ def _get_context_value_getter(
f'$.identity.traits["{escape_double_quotes(property)}"]',
)

def getter(context: EvaluationContext) -> ContextValue:
def getter(context: EvaluationContext[SegmentMetadataT]) -> ContextValue:
if typing.TYPE_CHECKING: # pragma: no cover
# Ugly hack to satisfy mypy :(
data = dict(context)
Expand Down
8 changes: 6 additions & 2 deletions flag_engine/segments/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Any, Literal, Union, get_args
from __future__ import annotations

from typing_extensions import TypeGuard
from typing import Any, Dict, Literal, Union, get_args

from typing_extensions import TypeGuard, TypeVar

SegmentMetadataT = TypeVar("SegmentMetadataT", default=Dict[str, object])

ConditionOperator = Literal[
"EQUAL",
Expand Down
Empty file removed flag_engine/types/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions tests/unit/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import json
from typing import TYPE_CHECKING, TypedDict

if not TYPE_CHECKING:
# `reveal_type` is a pseudo-builtin only available when type checking.
# Define a no-op version here so that we can call it in the tests.
def reveal_type(x: object) -> None: ...


from flag_engine.context.types import EvaluationContext, IdentityContext, SegmentContext
from flag_engine.engine import get_evaluation_result
Expand Down Expand Up @@ -357,3 +364,80 @@ def test_get_evaluation_result__segment_override__no_priority__returns_expected(
{"key": "3", "name": "another_segment"},
],
}


def test_segment_metadata_generic_type__returns_expected() -> None:
# Given
class CustomMetadata(TypedDict):
foo: str
bar: int

segment_metadata = CustomMetadata(foo="hello", bar=123)

evaluation_context: EvaluationContext[CustomMetadata] = {
"environment": {"key": "api-key", "name": ""},
"segments": {
"1": {
"key": "1",
"name": "my_segment",
"rules": [
{
"type": "ALL",
"conditions": [
{
"property": "$.environment.name",
"operator": "EQUAL",
"value": "",
}
],
"rules": [],
}
],
"metadata": segment_metadata,
},
},
}

# When
result = get_evaluation_result(evaluation_context)

# Then
assert result["segments"][0]["metadata"] is segment_metadata
reveal_type(result["segments"][0]["metadata"]) # CustomMetadata


def test_segment_metadata_generic_type__default__returns_expected() -> None:
# Given
segment_metadata = {"hello": object()}

# we don't specify generic type, but mypy is happy with this
evaluation_context: EvaluationContext = {
"environment": {"key": "api-key", "name": ""},
"segments": {
"1": {
"key": "1",
"name": "my_segment",
"rules": [
{
"type": "ALL",
"conditions": [
{
"property": "$.environment.name",
"operator": "EQUAL",
"value": "",
}
],
"rules": [],
}
],
"metadata": segment_metadata,
},
},
}

# When
result = get_evaluation_result(evaluation_context)

# Then
assert result["segments"][0]["metadata"] is segment_metadata
reveal_type(result["segments"][0]["metadata"]) # Dict[str, object]