diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index cc9db06a..fa728730 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -2178,6 +2178,46 @@ class PrimaryKey: field: PydanticFieldInfo +class PrimaryKeyAccessor: + """Descriptor that provides access to the primary key value. + + When a model uses a custom primary key field (e.g., `x: int = Field(primary_key=True)`), + this descriptor allows accessing the primary key value via `.pk` for consistency. + + This solves GitHub issue #570 where accessing `.pk` on a model with a custom + primary key returned an ExpressionProxy instead of the actual value. + """ + + def __get__(self, obj, objtype=None): + if obj is None: + # Class-level access - return ExpressionProxy for query building + # if the model is indexed, otherwise return the descriptor itself + if hasattr(objtype, "_meta") and hasattr(objtype._meta, "primary_key"): + pk_field = objtype._meta.primary_key.field + pk_name = objtype._meta.primary_key.name + # Return ExpressionProxy for query building (e.g., Model.pk == value) + return ExpressionProxy(pk_field, []) + return self + + # Instance-level access - return the actual primary key value + if hasattr(obj._meta, "primary_key") and obj._meta.primary_key is not None: + pk_name = obj._meta.primary_key.name + # Use __dict__ to get the instance value directly, avoiding descriptor recursion + if pk_name in obj.__dict__: + return obj.__dict__[pk_name] + # Fallback to getattr for computed/inherited attributes + return getattr(obj, pk_name) + return None + + def __set__(self, obj, value): + # When setting pk, set the actual primary key field + if hasattr(obj._meta, "primary_key") and obj._meta.primary_key is not None: + pk_name = obj._meta.primary_key.name + obj.__dict__[pk_name] = value + else: + obj.__dict__["pk"] = value + + if PYDANTIC_V2: class RedisOmConfig(ConfigDict): @@ -2354,6 +2394,41 @@ class Config: if is_primary_key: new_class._meta.primary_key = PrimaryKey(name=field_name, field=field) + # Count custom primary keys (not the default 'pk') to determine if we + # should set up the PrimaryKeyAccessor. We only do this when there's + # exactly one custom primary key. Multiple custom primary keys will be + # caught by validate_primary_key() later. + custom_pk_count = 0 + for field_name, field in model_fields.items(): + if field_name == "pk": + continue + # Check for primary key + check_field = field + if ( + not isinstance(field, FieldInfo) + and hasattr(field, "metadata") + and len(field.metadata) > 0 + and isinstance(field.metadata[0], FieldInfo) + ): + check_field = field.metadata[0] + if getattr(check_field, "primary_key", None) is True: + custom_pk_count += 1 + + # If there's exactly one custom primary key (not the default 'pk'), set up + # a PrimaryKeyAccessor so that .pk always returns the correct value. + # This fixes GitHub issue #570. + if ( + custom_pk_count == 1 + and hasattr(new_class._meta, "primary_key") + and new_class._meta.primary_key is not None + and new_class._meta.primary_key.name != "pk" + ): + # Remove 'pk' from model_fields since we have a custom primary key + if "pk" in model_fields: + model_fields.pop("pk") + # Set up PrimaryKeyAccessor descriptor for .pk access + setattr(new_class, "pk", PrimaryKeyAccessor()) + if not getattr(new_class._meta, "global_key_prefix", None): new_class._meta.global_key_prefix = getattr( base_meta, "global_key_prefix", "" @@ -2604,7 +2679,8 @@ def validate_primary_key(cls): if primary_keys == 0: raise RedisModelError("You must define a primary key for the model") elif primary_keys == 2: - cls.model_fields.pop("pk") + # Remove 'pk' from model_fields if it exists (may already be removed by ModelMeta) + cls.model_fields.pop("pk", None) elif primary_keys > 2: raise RedisModelError("You must define only one primary key for a model") diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index 5a010d4b..52bf7fe8 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -23,6 +23,7 @@ RedisModel, RedisModelError, ) +from aredis_om.model.model import ExpressionProxy # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. @@ -1535,3 +1536,36 @@ class Meta: await a2.save() r2 = await Attachment.get(a2.pk) assert r2.data == b"\x89PNG\x00\xff" + + + +@py_test_mark_asyncio +async def test_custom_primary_key_pk_property(key_prefix, redis): + """Test that .pk returns the actual value when using a custom primary key. + + Regression test for GitHub issue #570: accessing .pk on a model with + custom primary_key=True returned an ExpressionProxy instead of the value. + """ + + class ModelWithCustomPK(HashModel, index=True): + x: int = Field(primary_key=True) + name: str + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + instance = ModelWithCustomPK(x=42, name="test") + + # pk should return the actual value, not an ExpressionProxy + assert instance.pk == 42 + assert not isinstance(instance.pk, ExpressionProxy) + + # The custom field should also work for queries + await instance.save() + retrieved = await ModelWithCustomPK.get(42) + assert retrieved.pk == 42 + assert retrieved.x == 42 + assert retrieved.name == "test" diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 7375a57f..636d645f 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -25,6 +25,7 @@ RedisModel, RedisModelError, ) +from aredis_om.model.model import ExpressionProxy # We need to run this check as sync code (during tests) even in async mode # because we call it in the top-level module scope. @@ -1990,3 +1991,36 @@ class Meta: retrieved = await Document.get(doc.pk) assert retrieved.file.content == binary_content assert retrieved.file.mime_type == "image/png" + + + +@py_test_mark_asyncio +async def test_custom_primary_key_pk_property(key_prefix, redis): + """Test that .pk returns the actual value when using a custom primary key. + + Regression test for GitHub issue #570: accessing .pk on a model with + custom primary_key=True returned an ExpressionProxy instead of the value. + """ + + class ModelWithCustomPK(JsonModel, index=True): + x: int = Field(primary_key=True) + name: str + + class Meta: + global_key_prefix = key_prefix + database = redis + + await Migrator().run() + + instance = ModelWithCustomPK(x=42, name="test") + + # pk should return the actual value, not an ExpressionProxy + assert instance.pk == 42 + assert not isinstance(instance.pk, ExpressionProxy) + + # The custom field should also work for queries + await instance.save() + retrieved = await ModelWithCustomPK.get(42) + assert retrieved.pk == 42 + assert retrieved.x == 42 + assert retrieved.name == "test"