diff --git a/examples/fodo.py b/examples/fodo.py index 5d8430e..5757603 100644 --- a/examples/fodo.py +++ b/examples/fodo.py @@ -6,10 +6,10 @@ # Add the parent directory to sys.path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../src/"))) -from pals.MagneticMultipoleParameters import MagneticMultipoleParameters -from pals.Drift import Drift -from pals.Quadrupole import Quadrupole -from pals.BeamLine import BeamLine +from pals import MagneticMultipoleParameters +from pals import Drift +from pals import Quadrupole +from pals import BeamLine def main(): diff --git a/src/pals/BeamLine.py b/src/pals/BeamLine.py index 76c271b..433df35 100644 --- a/src/pals/BeamLine.py +++ b/src/pals/BeamLine.py @@ -1,10 +1,10 @@ from pydantic import ConfigDict, Field, model_validator from typing import Annotated, List, Literal, Union -from pals.BaseElement import BaseElement -from pals.ThickElement import ThickElement -from pals.Drift import Drift -from pals.Quadrupole import Quadrupole +from .BaseElement import BaseElement +from .ThickElement import ThickElement +from .Drift import Drift +from .Quadrupole import Quadrupole class BeamLine(BaseElement): diff --git a/src/pals/__init__.py b/src/pals/__init__.py index e69de29..1433652 100644 --- a/src/pals/__init__.py +++ b/src/pals/__init__.py @@ -0,0 +1,22 @@ +"""Top-level package for PALS. + +Re-export commonly used classes from submodules so callers can use +simpler import statements like `from pals import Drift` instead of +`from pals.Drift import Drift`. +""" + +from .BaseElement import BaseElement +from .BeamLine import BeamLine +from .Drift import Drift +from .MagneticMultipoleParameters import MagneticMultipoleParameters +from .Quadrupole import Quadrupole +from .ThickElement import ThickElement + +__all__ = [ + "BaseElement", + "BeamLine", + "Drift", + "MagneticMultipoleParameters", + "Quadrupole", + "ThickElement", +] diff --git a/tests/test_schema.py b/tests/test_schema.py index 34202be..74e271b 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -8,12 +8,12 @@ from pydantic import ValidationError -from pals.MagneticMultipoleParameters import MagneticMultipoleParameters -from pals.BaseElement import BaseElement -from pals.ThickElement import ThickElement -from pals.Drift import Drift -from pals.Quadrupole import Quadrupole -from pals.BeamLine import BeamLine +from pals import MagneticMultipoleParameters +from pals import BaseElement +from pals import ThickElement +from pals import Drift +from pals import Quadrupole +from pals import BeamLine def test_BaseElement():