diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 60874149bf..c39785c226 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -10,6 +10,7 @@ Any, Callable, ForwardRef, + Literal, Optional, TypeVar, Union, @@ -218,6 +219,13 @@ def get_sa_type_from_type_annotation(annotation: Any) -> Any: # Optional unions are allowed use_type = bases[0] if bases[0] is not NoneType else bases[1] return get_sa_type_from_type_annotation(use_type) + if origin is Literal: + literal_args = get_args(annotation) + if all(isinstance(arg, bool) for arg in literal_args): # all bools + return bool + if all(isinstance(arg, int) for arg in literal_args): # all ints + return int + return str return origin def get_sa_type_from_field(field: Any) -> Any: @@ -469,6 +477,14 @@ def is_field_noneable(field: "FieldInfo") -> bool: return field.allow_none # type: ignore[no-any-return, attr-defined] def get_sa_type_from_field(field: Any) -> Any: + if get_origin(field.type_) is Literal: + literal_args = get_args(field.type_) + if all(isinstance(arg, bool) for arg in literal_args): # all bools + return bool + if all(isinstance(arg, int) for arg in literal_args): # all ints + return int + return str + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") diff --git a/tests/test_main.py b/tests/test_main.py index c1508d181f..fc3d12b291 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional import pytest from sqlalchemy.exc import IntegrityError @@ -125,3 +125,50 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_literal_str(clear_sqlmodel, caplog): + """Test https://github.com/fastapi/sqlmodel/issues/57""" + + class Model(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + all_str: Literal["a", "b", "c"] + mixed: Literal["yes", "no", 1, 0] + all_int: Literal[1, 2, 3] + int_bool: Literal[0, 1, True, False] + all_bool: Literal[True, False] + + obj = Model( + all_str="a", + mixed="yes", + all_int=1, + int_bool=True, + all_bool=False, + ) + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + # Check DDL + assert "all_str VARCHAR NOT NULL" in caplog.text + assert "mixed VARCHAR NOT NULL" in caplog.text + assert "all_int INTEGER NOT NULL" in caplog.text + assert "int_bool INTEGER NOT NULL" in caplog.text + assert "all_bool BOOLEAN NOT NULL" in caplog.text + + # Check query + with Session(engine) as session: + session.add(obj) + session.commit() + session.refresh(obj) + assert isinstance(obj.all_str, str) + assert obj.all_str == "a" + assert isinstance(obj.mixed, str) + assert obj.mixed == "yes" + assert isinstance(obj.all_int, int) + assert obj.all_int == 1 + assert isinstance(obj.int_bool, int) + assert obj.int_bool == 1 + assert isinstance(obj.all_bool, bool) + assert obj.all_bool is False