diff --git a/src/diffcalc_API/errors/hkl.py b/src/diffcalc_API/errors/hkl.py index 0738d69..f27ef5a 100644 --- a/src/diffcalc_API/errors/hkl.py +++ b/src/diffcalc_API/errors/hkl.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np from diffcalc_API.errors.definitions import ( @@ -16,8 +18,10 @@ class ErrorCodes(ErrorCodesBase): class InvalidMillerIndicesError(DiffcalcAPIException): - def __init__(self) -> None: - self.detail = "At least one of the hkl indices must be non-zero" + def __init__(self, detail: Optional[str] = None) -> None: + self.detail = ( + "At least one of the hkl indices must be non-zero" if not detail else detail + ) self.status_code = ErrorCodes.INVALID_MILLER_INDICES diff --git a/src/diffcalc_API/examples/ub.py b/src/diffcalc_API/examples/ub.py index 6e22730..491638c 100644 --- a/src/diffcalc_API/examples/ub.py +++ b/src/diffcalc_API/examples/ub.py @@ -3,13 +3,16 @@ AddReflectionParams, EditOrientationParams, EditReflectionParams, + HklModel, + PositionModel, SetLatticeParams, + XyzModel, ) add_reflection: AddReflectionParams = AddReflectionParams( **{ - "hkl": [0, 0, 1], - "position": [7.31, 0.0, 10.62, 0, 0.0, 0], + "hkl": HklModel(h=0, k=0, l=1), + "position": PositionModel(mu=7.31, delta=0.0, nu=10.62, eta=0, chi=0.0, phi=0), "energy": 12.39842, "tag": "refl1", } @@ -21,15 +24,15 @@ add_orientation: AddOrientationParams = AddOrientationParams( **{ - "hkl": [0, 1, 0], - "xyz": [0, 1, 0], + "hkl": HklModel(h=0, k=1, l=0), + "xyz": XyzModel(x=0, y=1, z=0), "tag": "plane", } ) edit_orientation: EditOrientationParams = EditOrientationParams( **{ - "hkl": (0, 1, 0), + "hkl": HklModel(h=0, k=1, l=0), "tag_or_idx": "plane", } ) diff --git a/src/diffcalc_API/models/ub.py b/src/diffcalc_API/models/ub.py index cb79b95..e6ce562 100644 --- a/src/diffcalc_API/models/ub.py +++ b/src/diffcalc_API/models/ub.py @@ -1,8 +1,29 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Union from pydantic import BaseModel +class HklModel(BaseModel): + h: float + k: float + l: float + + +class XyzModel(BaseModel): + x: float + y: float + z: float + + +class PositionModel(BaseModel): + mu: float + delta: float + nu: float + eta: float + chi: float + phi: float + + class SetLatticeParams(BaseModel): system: Optional[Union[str, float]] = None a: Optional[float] = None @@ -14,32 +35,30 @@ class SetLatticeParams(BaseModel): class AddReflectionParams(BaseModel): - hkl: Tuple[float, float, float] - position: Tuple[ - float, float, float, float, float, float - ] # allows easier user input + hkl: HklModel + position: PositionModel energy: float tag: Optional[str] = None class AddOrientationParams(BaseModel): - hkl: Tuple[float, float, float] - xyz: Tuple[float, float, float] - position: Optional[Tuple[float, float, float, float, float, float]] = None + hkl: HklModel + xyz: XyzModel + position: Optional[PositionModel] = None tag: Optional[str] = None class EditReflectionParams(BaseModel): - hkl: Optional[Tuple[float, float, float]] = None - position: Optional[Tuple[float, float, float, float, float, float]] = None + hkl: Optional[HklModel] = None + position: Optional[PositionModel] = None energy: Optional[float] = None tag_or_idx: Union[int, str] class EditOrientationParams(BaseModel): - hkl: Optional[Tuple[float, float, float]] = None - xyz: Optional[Tuple[float, float, float]] = None - position: Optional[Tuple[float, float, float, float, float, float]] = None + hkl: Optional[HklModel] = None + xyz: Optional[XyzModel] = None + position: Optional[PositionModel] = None tag_or_idx: Union[int, str] diff --git a/src/diffcalc_API/routes/hkl.py b/src/diffcalc_API/routes/hkl.py index 58bd627..482de4d 100644 --- a/src/diffcalc_API/routes/hkl.py +++ b/src/diffcalc_API/routes/hkl.py @@ -1,17 +1,14 @@ -from typing import Optional, Tuple, Union +from typing import List, Optional, Union from fastapi import APIRouter, Depends, Query +from diffcalc_API.models.ub import HklModel, PositionModel from diffcalc_API.services import hkl as service from diffcalc_API.stores.protocol import HklCalcStore, get_store router = APIRouter(prefix="/hkl", tags=["hkl"]) -SingleConstraint = Union[Tuple[str, float], str] -PositionType = Tuple[float, float, float] - - @router.get("/{name}/UB") async def calculate_ub( name: str, @@ -27,7 +24,8 @@ async def calculate_ub( @router.get("/{name}/position/lab") async def lab_position_from_miller_indices( name: str, - miller_indices: Tuple[float, float, float] = Query(example=[0, 0, 1]), + # miller_indices: List[float] = Query(example=[0, 0, 1]), + miller_indices: HklModel = Depends(), wavelength: float = Query(..., example=1.0), store: HklCalcStore = Depends(get_store), collection: Optional[str] = Query(default=None, example="B07"), @@ -42,8 +40,9 @@ async def lab_position_from_miller_indices( @router.get("/{name}/position/hkl") async def miller_indices_from_lab_position( name: str, - pos: Tuple[float, float, float, float, float, float] = Query( - ..., example=[7.31, 0, 10.62, 0, 0, 0] + pos: PositionModel = Depends( + # ..., example={"mu": 7.31, "delta": 0, "nu": 10.62, + # "eta": 0, "chi": 0, "phi": 0} ), wavelength: float = Query(..., example=1.0), store: HklCalcStore = Depends(get_store), @@ -58,9 +57,9 @@ async def miller_indices_from_lab_position( @router.get("/{name}/scan/hkl") async def scan_hkl( name: str, - start: PositionType = Query(..., example=(1, 0, 1)), - stop: PositionType = Query(..., example=(2, 0, 2)), - inc: PositionType = Query(..., example=(0.1, 0, 0.1)), + start: List[float] = Query(..., example=[1, 0, 1]), + stop: List[float] = Query(..., example=[2, 0, 2]), + inc: List[float] = Query(..., example=(0.1, 0, 0.1)), wavelength: float = Query(..., example=1), store: HklCalcStore = Depends(get_store), collection: Optional[str] = Query(default=None, example="B07"), @@ -77,7 +76,8 @@ async def scan_wavelength( start: float = Query(..., example=1.0), stop: float = Query(..., example=2.0), inc: float = Query(..., example=0.2), - hkl: PositionType = Query(..., example=(1, 0, 1)), + # hkl: PositionType = Query(..., example=(1, 0, 1)), + hkl: HklModel = Depends(), store: HklCalcStore = Depends(get_store), collection: Optional[str] = Query(default=None, example="B07"), ): @@ -94,7 +94,8 @@ async def scan_constraint( start: float = Query(..., example=1), stop: float = Query(..., example=4), inc: float = Query(..., example=1), - hkl: PositionType = Query(..., example=(1, 0, 1)), + # hkl: PositionType = Query(..., example=(1, 0, 1)), + hkl: HklModel = Depends(), wavelength: float = Query(..., example=1.0), store: HklCalcStore = Depends(get_store), collection: Optional[str] = Query(default=None, example="B07"), diff --git a/src/diffcalc_API/routes/ub.py b/src/diffcalc_API/routes/ub.py index 0f41bb0..c077ad6 100644 --- a/src/diffcalc_API/routes/ub.py +++ b/src/diffcalc_API/routes/ub.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional from fastapi import APIRouter, Body, Depends, Query @@ -11,6 +11,7 @@ DeleteParams, EditOrientationParams, EditReflectionParams, + HklModel, SetLatticeParams, ) from diffcalc_API.services import ub as service @@ -152,7 +153,7 @@ async def set_lattice( async def modify_property( name: str, property: str, - target_value: Tuple[float, float, float] = Body(..., example=[1, 0, 0]), + target_value: HklModel = Body(..., example={"h": 1, "k": 0, "l": 0}), store: HklCalcStore = Depends(get_store), collection: Optional[str] = Query(default=None, example="B07"), ): diff --git a/src/diffcalc_API/server.py b/src/diffcalc_API/server.py index b64b3f1..1d851ca 100644 --- a/src/diffcalc_API/server.py +++ b/src/diffcalc_API/server.py @@ -1,3 +1,4 @@ +import traceback from typing import Optional from diffcalc.util import DiffcalcException @@ -46,6 +47,8 @@ async def server_exceptions_middleware(request: Request, call_next): return await call_next(request) except Exception as e: # you probably want some kind of logging here + tb = traceback.format_exc() + print(tb) return responses.JSONResponse( status_code=500, diff --git a/src/diffcalc_API/services/hkl.py b/src/diffcalc_API/services/hkl.py index a8c17d8..cb6a9ca 100644 --- a/src/diffcalc_API/services/hkl.py +++ b/src/diffcalc_API/services/hkl.py @@ -1,18 +1,17 @@ from itertools import product -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np from diffcalc.hkl.geometry import Position from diffcalc_API.errors.hkl import InvalidMillerIndicesError, InvalidScanBoundsError +from diffcalc_API.models.ub import HklModel, PositionModel from diffcalc_API.stores.protocol import HklCalcStore -PositionType = Tuple[float, float, float] - async def lab_position_from_miller_indices( name: str, - miller_indices: Tuple[float, float, float], + miller_indices: HklModel, wavelength: float, store: HklCalcStore, collection: Optional[str], @@ -22,33 +21,39 @@ async def lab_position_from_miller_indices( if all([idx == 0 for idx in miller_indices]): raise InvalidMillerIndicesError() - all_positions = hklcalc.get_position(*miller_indices, wavelength) + all_positions = hklcalc.get_position(*miller_indices.dict().values(), wavelength) return combine_lab_position_results(all_positions) async def miller_indices_from_lab_position( name: str, - pos: Tuple[float, float, float, float, float, float], + pos: PositionModel, wavelength: float, store: HklCalcStore, collection: Optional[str], -) -> Tuple[Any, ...]: +) -> HklModel: hklcalc = await store.load(name, collection) - position = hklcalc.get_hkl(Position(*pos), wavelength) - return tuple(np.round(position, 16)) + hkl = np.round(hklcalc.get_hkl(Position(**pos.dict()), wavelength), 16) + return HklModel(h=hkl[0], k=hkl[1], l=hkl[2]) async def scan_hkl( name: str, - start: PositionType, - stop: PositionType, - inc: PositionType, + start: List[float], + stop: List[float], + inc: List[float], wavelength: float, store: HklCalcStore, collection: Optional[str], ) -> Dict[str, List[Dict[str, float]]]: hklcalc = await store.load(name, collection) + + if (len(start) != 3) or (len(stop) != 3) or (len(inc) != 3): + raise InvalidMillerIndicesError( + detail="start, stop and inc must have three floats for each miller index." + ) + axes_values = [ generate_axis(start[i], stop[i], inc[i]) if inc[i] != 0 else [0] for i in range(3) @@ -71,7 +76,7 @@ async def scan_wavelength( start: float, stop: float, inc: float, - hkl: PositionType, + hkl: HklModel, store: HklCalcStore, collection: Optional[str], ) -> Dict[str, List[Dict[str, float]]]: @@ -84,7 +89,7 @@ async def scan_wavelength( result = {} for wavelength in wavelengths: - all_positions = hklcalc.get_position(*hkl, wavelength) + all_positions = hklcalc.get_position(*hkl.dict().values(), wavelength) result[f"{wavelength}"] = combine_lab_position_results(all_positions) return result @@ -96,7 +101,7 @@ async def scan_constraint( start: float, stop: float, inc: float, - hkl: PositionType, + hkl: HklModel, wavelength: float, store: HklCalcStore, collection: Optional[str], @@ -109,7 +114,7 @@ async def scan_constraint( result = {} for value in np.arange(start, stop + inc, inc): setattr(hklcalc, constraint, value) - all_positions = hklcalc.get_position(*hkl, wavelength) + all_positions = hklcalc.get_position(*hkl.dict().values(), wavelength) result[f"{value}"] = combine_lab_position_results(all_positions) return result diff --git a/src/diffcalc_API/services/ub.py b/src/diffcalc_API/services/ub.py index 0959f05..e61ff37 100644 --- a/src/diffcalc_API/services/ub.py +++ b/src/diffcalc_API/services/ub.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Union from diffcalc.hkl.geometry import Position @@ -8,6 +8,7 @@ AddReflectionParams, EditOrientationParams, EditReflectionParams, + HklModel, SetLatticeParams, ) from diffcalc_API.stores.protocol import HklCalcStore @@ -28,8 +29,8 @@ async def add_reflection( hklcalc = await store.load(name, collection) hklcalc.ubcalc.add_reflection( - params.hkl, - Position(*params.position), + tuple(params.hkl.dict().values()), + Position(**params.position.dict()), params.energy, params.tag, ) @@ -50,10 +51,13 @@ async def edit_reflection( except (IndexError, ValueError): raise ReferenceRetrievalError(params.tag_or_idx, "reflection") + # TODO: make this more readable... hklcalc.ubcalc.edit_reflection( params.tag_or_idx, - params.hkl if params.hkl else (reflection.h, reflection.k, reflection.l), - Position(params.position) if params.position else reflection.pos, + tuple(params.hkl.dict().values()) + if params.hkl + else (reflection.h, reflection.k, reflection.l), + Position(params.position.dict()) if params.position else reflection.pos, params.energy if params.energy else reflection.energy, params.tag_or_idx if isinstance(params.tag_or_idx, str) else None, ) @@ -87,10 +91,10 @@ async def add_orientation( ) -> None: hklcalc = await store.load(name, collection) - position = Position(*params.position) if params.position else None + position = Position(*params.position.dict()) if params.position else None hklcalc.ubcalc.add_orientation( - params.hkl, - params.xyz, + tuple(params.hkl.dict().values()), + tuple(params.xyz.dict().values()), position, params.tag, ) @@ -113,9 +117,13 @@ async def edit_orientation( hklcalc.ubcalc.edit_orientation( params.tag_or_idx, - params.hkl if params.hkl else (orientation.h, orientation.k, orientation.l), - params.xyz if params.xyz else (orientation.x, orientation.y, orientation.z), - Position(params.position) if params.position else orientation.pos, + tuple(params.hkl.dict().values()) + if params.hkl + else (orientation.h, orientation.k, orientation.l), + tuple(params.xyz.dict().values()) + if params.xyz + else (orientation.x, orientation.y, orientation.z), + Position(params.position.dict()) if params.position else orientation.pos, params.tag_or_idx if isinstance(params.tag_or_idx, str) else None, ) @@ -153,12 +161,12 @@ async def set_lattice( async def modify_property( name: str, property: str, - target_value: Tuple[float, float, float], + target_value: HklModel, store: HklCalcStore, collection: Optional[str], ) -> None: hklcalc = await store.load(name, collection) - setattr(hklcalc.ubcalc, property, target_value) + setattr(hklcalc.ubcalc, property, tuple(target_value.dict().values())) await store.save(name, hklcalc, collection) diff --git a/tests/test_hklcalc.py b/tests/test_hklcalc.py index 138b2b7..43cb99e 100644 --- a/tests/test_hklcalc.py +++ b/tests/test_hklcalc.py @@ -40,7 +40,7 @@ def client() -> TestClient: def test_miller_indices_stay_the_same_after_transformation(client: TestClient): lab_positions = client.get( "/hkl/test/position/lab", - params={"miller_indices": [0, 0, 1], "wavelength": 1}, + params={"h": 0, "k": 0, "l": 1, "wavelength": 1}, ) assert lab_positions.status_code == 200 @@ -50,22 +50,21 @@ def test_miller_indices_stay_the_same_after_transformation(client: TestClient): miller_positions = client.get( "/hkl/test/position/hkl", params={ - "pos": [ - pos["mu"], - pos["delta"], - pos["nu"], - pos["eta"], - pos["chi"], - pos["phi"], - ], + "mu": pos["mu"], + "delta": pos["delta"], + "nu": pos["nu"], + "eta": pos["eta"], + "chi": pos["chi"], + "phi": pos["phi"], "wavelength": 1, }, ) assert miller_positions.status_code == 200 - assert np.all( - np.round(miller_positions.json()["payload"], 8) == np.array([0, 0, 1]) - ) + results = miller_positions.json()["payload"] + assert np.round(results["h"], 8) == 0 + assert np.round(results["k"], 8) == 0 + assert np.round(results["l"], 8) == 1 def test_scan_hkl( @@ -86,6 +85,25 @@ def test_scan_hkl( assert len(scan_results.keys()) == 9 +def test_scan_hkl_raises_invalid_miller_indices_error_for_wrong_inputs( + client: TestClient, +): + lab_positions = client.get( + "/hkl/test/scan/hkl", + params={ + "start": [1, 0, 1], + "stop": [2, 0, 2], + "inc": [0.5, 0], + "wavelength": 1, + }, + ) + + assert ( + ast.literal_eval(lab_positions.content.decode())["type"] + == "" + ) + + def test_scan_wavelength( client: TestClient, ): @@ -95,7 +113,9 @@ def test_scan_wavelength( "start": 1, "stop": 2, "inc": 0.5, - "hkl": [1, 0, 1], + "h": 1, + "k": 0, + "l": 1, }, ) scan_results = lab_positions.json()["payload"] @@ -113,7 +133,9 @@ def test_scan_constraint( "start": 1, "stop": 2, "inc": 0.5, - "hkl": [1, 0, 1], + "h": 1, + "k": 0, + "l": 1, "wavelength": 1.0, }, ) @@ -142,7 +164,9 @@ def test_invalid_scans(client: TestClient): "start": 1, "stop": 2, "inc": -0.5, - "hkl": [1, 0, 1], + "h": 1, + "k": 0, + "l": 1, }, ) diff --git a/tests/test_ubcalc.py b/tests/test_ubcalc.py index 7bd525a..365750d 100644 --- a/tests/test_ubcalc.py +++ b/tests/test_ubcalc.py @@ -55,8 +55,8 @@ def test_add_reflection(client: TestClient): response = client.post( "/ub/test/reflection", json={ - "hkl": [0, 0, 1], - "position": [7, 0, 10, 0, 0, 0], + "hkl": {"h": 0, "k": 0, "l": 1}, + "position": {"mu": 7, "delta": 0, "nu": 10, "eta": 0, "chi": 0, "phi": 0}, "energy": 12, "tag": "foo", }, @@ -117,8 +117,8 @@ def test_add_orientation(client: TestClient): response = client.post( "/ub/test/orientation", json={ - "hkl": [0, 1, 0], - "xyz": [0, 1, 0], + "hkl": {"h": 0, "k": 1, "l": 0}, + "xyz": {"x": 0, "y": 1, "z": 0}, "tag": "bar", }, ) @@ -134,7 +134,7 @@ def test_edit_orientation(client: TestClient): response = client.put( "/ub/test/orientation", json={ - "xyz": [1, 1, 0], + "xyz": {"x": 1, "y": 1, "z": 0}, "tag_or_idx": "bar", }, ) @@ -167,7 +167,7 @@ def test_edit_or_delete_orientation_fails_for_non_existing_orientation( edit_response = client.put( "/ub/test/orientation", json={ - "xyz": [1, 1, 0], + "xyz": {"x": 1, "y": 1, "z": 0}, "tag_or_idx": "bar", }, ) @@ -210,7 +210,7 @@ def test_set_lattice_fails_for_empty_data(client: TestClient): def test_modify_property(client: TestClient): response = client.put( "/ub/test/n_hkl", - json=[0, 0, 1], + json={"h": 0, "k": 0, "l": 1}, ) assert response.status_code == 200 @@ -220,6 +220,6 @@ def test_modify_property(client: TestClient): def test_modify_non_existent_property(client: TestClient): response = client.put( "/ub/test/silly_property", - json=[0, 0, 1], + json={"h": 0, "k": 0, "l": 1}, ) assert response.status_code == ErrorCodes.INVALID_PROPERTY