diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 2c3ba04769..0c83e2f334 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -10,15 +10,21 @@ from bluesky.callbacks.best_effort import BestEffortCallback from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker +from click.exceptions import ClickException from observability_utils.tracing import setup_tracing from pydantic import ValidationError from requests.exceptions import ConnectionError -from blueapi import __version__ +from blueapi import __version__, config from blueapi.cli.format import OutputFormat from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient -from blueapi.client.rest import BlueskyRemoteControlError +from blueapi.client.rest import ( + BlueskyRemoteControlError, + InvalidParameters, + UnauthorisedAccess, + UnknownPlan, +) from blueapi.config import ( ApplicationConfig, ConfigLoader, @@ -226,8 +232,10 @@ def run_plan( client: BlueapiClient = obj["client"] parameters = parameters or "{}" - task_id = "" - parsed_params = json.loads(parameters) if isinstance(parameters, str) else {} + try: + parsed_params = json.loads(parameters) if isinstance(parameters, str) else {} + except json.JSONDecodeError as jde: + raise ClickException(f"Parameters are not valid JSON: {jde}") from jde progress_bar = CliEventRenderer() callback = BestEffortCallback() @@ -240,18 +248,25 @@ def on_event(event: AnyEvent) -> None: try: task = Task(name=name, params=parsed_params) + except ValidationError as ve: + ip = InvalidParameters.from_validation_error(ve) + raise ClickException(ip.message()) from ip + + try: resp = client.run_task(task, on_event=on_event) - except ValidationError as e: - pprint(f"failed to validate the task parameters, {task_id}, error: {e}") - return + except config.MissingStompConfiguration as mse: + raise ClickException(*mse.args) from mse + except UnknownPlan as up: + raise ClickException(f"Plan '{name}' was not recognised") from up + except UnauthorisedAccess as ua: + raise ClickException("Unauthorised request") from ua + except InvalidParameters as ip: + raise ClickException(ip.message()) from ip except (BlueskyRemoteControlError, BlueskyStreamingError) as e: - pprint(f"server error with this message: {e}") - return - except ValueError: - pprint("task could not run") - return + raise ClickException(f"server error with this message: {e}") from e + except ValueError as ve: + raise ClickException(f"task could not run: {ve}") from ve - pprint(resp.model_dump()) if resp.task_status is not None and not resp.task_status.task_failed: print("Plan Succeeded") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index ad89a424b1..681293b373 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -8,7 +8,7 @@ start_as_current_span, ) -from blueapi.config import ApplicationConfig +from blueapi.config import ApplicationConfig, MissingStompConfiguration from blueapi.core.bluesky_types import DataEvent from blueapi.service.authentication import SessionManager from blueapi.service.model import ( @@ -214,8 +214,8 @@ def run_task( """ if self._events is None: - raise RuntimeError( - "Cannot run plans without Stomp configuration to track progress" + raise MissingStompConfiguration( + "Stomp configuration required to run plans is missing or disabled" ) task_response = self.create_task(task) diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index fcbb550736..5a5dfc3b63 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -8,7 +8,7 @@ get_tracer, start_as_current_span, ) -from pydantic import TypeAdapter +from pydantic import BaseModel, TypeAdapter, ValidationError from blueapi.config import RestConfig from blueapi.service.authentication import JWTAuth, SessionManager @@ -32,9 +32,17 @@ TRACER = get_tracer("rest") +class UnauthorisedAccess(Exception): + pass + + class BlueskyRemoteControlError(Exception): - def __init__(self, message: str) -> None: - super().__init__(message) + pass + + +class BlueskyRequestError(Exception): + def __init__(self, code: int, message: str) -> None: + super().__init__(message, code) class NoContent(Exception): @@ -44,6 +52,53 @@ def __init__(self, target_type: type) -> None: super().__init__(target_type) +class ParameterError(BaseModel): + loc: list[str | int] + msg: str + type: str + input: Any + + def field(self): + return ".".join(str(p) for p in self.loc[2:] or self.loc) + + def __str__(self) -> str: + match self.type: + case "missing": + return f"Missing value for {self.field()!r}" + case "extra_forbidden": + return f"Unexpected field {self.field()!r}" + case _: + return ( + f"Invalid value {self.input!r} for field {self.field()}: {self.msg}" + ) + + +class InvalidParameters(Exception): + def __init__(self, errors: list[ParameterError]): + self.errors = errors + + def message(self): + msg = "Incorrect parameters supplied" + if self.errors: + msg += "\n " + "\n ".join(str(e) for e in self.errors) + return msg + + @classmethod + def from_validation_error(cls, ve: ValidationError): + return cls( + [ + ParameterError( + loc=list(e["loc"]), msg=e["msg"], type=e["type"], input=e["input"] + ) + for e in ve.errors() + ] + ) + + +class UnknownPlan(Exception): + pass + + def _exception(response: requests.Response) -> Exception | None: code = response.status_code if code < 400: @@ -51,7 +106,31 @@ def _exception(response: requests.Response) -> Exception | None: elif code == 404: return KeyError(str(response.json())) else: - return BlueskyRemoteControlError(str(response)) + return BlueskyRemoteControlError(code, str(response)) + + +def _create_task_exceptions(response: requests.Response) -> Exception | None: + code = response.status_code + if code < 400: + return None + elif code == 401 or code == 403: + return UnauthorisedAccess() + elif code == 404: + return UnknownPlan() + elif code == 422: + try: + content = response.json() + return InvalidParameters( + TypeAdapter(list[ParameterError]).validate_python( + content.get("detail", []) + ) + ) + except Exception: + # If the error can't be parsed into something sensible, return the + # raw text in a generic exception so at least it gets reported + return BlueskyRequestError(code, response.text) + else: + return BlueskyRequestError(code, response.text) class BlueapiRestClient: @@ -106,6 +185,7 @@ def create_task(self, task: Task) -> TaskResponse: "/tasks", TaskResponse, method="POST", + get_exception=_create_task_exceptions, data=task.model_dump(), ) diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 01edca7fe3..2ad2af4fcd 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -291,3 +291,7 @@ def load(self) -> C: raise InvalidConfigError( f"Something is wrong with the configuration file: \n {error_details}" ) from exc + + +class MissingStompConfiguration(Exception): + pass diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 83998ad688..f60fe07a57 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -13,6 +13,7 @@ BlueskyRemoteControlError, ) from blueapi.client.event_bus import AnyEvent +from blueapi.client.rest import UnknownPlan from blueapi.config import ( ApplicationConfig, OIDCConfig, @@ -186,7 +187,7 @@ def test_create_task_and_delete_task_by_id(client: BlueapiClient): def test_create_task_validation_error(client: BlueapiClient): - with pytest.raises(KeyError, match="{'detail': 'Item not found'}"): + with pytest.raises(UnknownPlan): client.create_task(Task(name="Not-exists", params={"Not-exists": 0.0})) diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 6337d1148a..3e4938150f 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -12,6 +12,7 @@ from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError +from blueapi.config import MissingStompConfiguration from blueapi.core import DataEvent from blueapi.service.model import ( DeviceModel, @@ -380,8 +381,8 @@ def test_resume( def test_cannot_run_task_without_message_bus(client: BlueapiClient): with pytest.raises( - RuntimeError, - match="Cannot run plans without Stomp configuration to track progress", + MissingStompConfiguration, + match="Stomp configuration required to run plans is missing or disabled", ): client.run_task(Task(name="foo")) @@ -639,8 +640,8 @@ def test_cannot_run_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient ): with pytest.raises( - RuntimeError, - match="Cannot run plans without Stomp configuration to track progress", + MissingStompConfiguration, + match="Stomp configuration required to run plans is missing or disabled", ): with asserting_span_exporter(exporter, "grun_task"): client.run_task(Task(name="foo")) diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index 0d1b4e32f9..808db3754c 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -3,9 +3,19 @@ from unittest.mock import Mock, patch import pytest +import requests import responses -from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError +from blueapi.client.rest import ( + BlueapiRestClient, + BlueskyRemoteControlError, + BlueskyRequestError, + InvalidParameters, + ParameterError, + UnauthorisedAccess, + UnknownPlan, + _create_task_exceptions, +) from blueapi.config import OIDCConfig from blueapi.service.authentication import SessionCacheManager, SessionManager from blueapi.service.model import EnvironmentResponse @@ -49,6 +59,53 @@ def test_rest_error_code( rest.get_plans() +@pytest.mark.parametrize( + "code,content,expected_exception", + [ + (200, None, None), + (401, None, UnauthorisedAccess()), + (403, None, UnauthorisedAccess()), + (404, None, UnknownPlan()), + ( + 422, + """{ + "detail": [{ + "loc": ["body", "params", "foo"], + "type": "missing", + "msg": "missing value for foo", + "input": {} + }] + }""", + InvalidParameters( + [ + ParameterError( + loc=["body", "params", "foo"], + type="missing", + msg="missing value for foo", + input={}, + ) + ] + ), + ), + (450, "non-standard", BlueskyRequestError(450, "non-standard")), + (500, "internal_error", BlueskyRequestError(500, "internal_error")), + ], +) +def test_create_task_exceptions( + code: int, content: str | None, expected_exception: Exception +): + response = Mock(spec=requests.Response) + response.status_code = code + response.text = content + import json + + response.json.side_effect = lambda: json.loads(content) if content else None + err = _create_task_exceptions(response) + assert isinstance(err, type(expected_exception)) + if expected_exception is not None: + assert err.args == expected_exception.args + + def test_auth_request_functionality( rest_with_auth: BlueapiRestClient, mock_authn_server: responses.RequestsMock, @@ -99,3 +156,43 @@ def test_refresh_if_signature_expired( calls = mock_get_env.calls assert len(calls) == 1 assert calls[0].request.headers["Authorization"] == "Bearer new_token" + + +def test_parameter_error_field(): + p1 = ParameterError( + loc=["body", "parameters", "detectors", 0], + msg="error message", + type="error_type", + input="original_input", + ) + assert p1.field() == "detectors.0" + + +def test_parameter_error_missing_string(): + p1 = ParameterError( + loc=["body", "parameters", "field_one", 0], + msg="error_message", + type="missing", + input=None, + ) + assert str(p1) == "Missing value for 'field_one.0'" + + +def test_parameter_error_extra_string(): + p1 = ParameterError( + loc=["body", "parameters", "foo"], + msg="error_message", + type="extra_forbidden", + input={"foo": "bar"}, + ) + assert str(p1) == "Unexpected field 'foo'" + + +def test_parameter_error_other_string(): + p1 = ParameterError( + loc=["body", "parameters", "field_one", 0], + msg="error_message", + type="string_value", + input=34, + ) + assert str(p1) == "Invalid value 34 for field field_one.0: error_message" diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 7f9b4fbcbd..536fe91de5 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -17,7 +17,7 @@ from click.testing import CliRunner from opentelemetry import trace from ophyd_async.core import AsyncStatus -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from requests.exceptions import ConnectionError from responses import matchers from stomp.connect import StompConnection11 as Connection @@ -26,8 +26,18 @@ from blueapi.cli.cli import main from blueapi.cli.format import OutputFormat, fmt_dict from blueapi.client.event_bus import BlueskyStreamingError -from blueapi.client.rest import BlueskyRemoteControlError -from blueapi.config import ApplicationConfig, ScratchConfig, ScratchRepository +from blueapi.client.rest import ( + BlueskyRemoteControlError, + InvalidParameters, + ParameterError, + UnauthorisedAccess, + UnknownPlan, +) +from blueapi.config import ( + ApplicationConfig, + ScratchConfig, + ScratchRepository, +) from blueapi.core.bluesky_types import DataEvent, Plan from blueapi.service.model import ( DeviceModel, @@ -208,8 +218,8 @@ def test_submit_plan_without_stomp(runner: CliRunner): ) assert ( - str(result.exception) - == "Cannot run plans without Stomp configuration to track progress" + result.stdout + == "Error: Stomp configuration required to run plans is missing or disabled\n" ) @@ -222,10 +232,9 @@ def test_invalid_stomp_config_for_listener(runner: CliRunner): def test_cannot_run_plans_without_stomp_config(runner: CliRunner): result = runner.invoke(main, ["controller", "run", "sleep", '{"time": 5}']) assert result.exit_code == 1 - assert isinstance(result.exception, RuntimeError) assert ( - str(result.exception) - == "Cannot run plans without Stomp configuration to track progress" + result.stdout + == "Error: Stomp configuration required to run plans is missing or disabled\n" ) @@ -408,18 +417,37 @@ def test_env_reload_server_side_error(runner: CliRunner): @pytest.mark.parametrize( "exception, error_message", [ + (UnknownPlan(), "Error: Plan 'sleep' was not recognised\n"), + (UnauthorisedAccess(), "Error: Unauthorised request\n"), ( - ValidationError.from_exception_data(title="Base model", line_errors=[]), - "('failed to validate the task parameters, ," - + " error: 0 validation errors for '\n 'Base model\\n')\n", + InvalidParameters( + errors=[ + ParameterError( + loc=["body", "params", "foo"], + type="missing", + msg="Foo is missing", + input=None, + ) + ] + ), + "Error: Incorrect parameters supplied\n Missing value for 'foo'\n", ), ( BlueskyRemoteControlError("Server error"), - "'server error with this message: Server error'\n", + "Error: server error with this message: Server error\n", + ), + ( + ValueError("Error parsing parameters"), + "Error: task could not run: Error parsing parameters\n", ), - (ValueError("Error parsing parameters"), "'task could not run'\n"), ], - ids=["validation_error", "remote_control", "value_error"], + ids=[ + "unknown_plan", + "unauthorised_access", + "invalid_parameters", + "remote_control", + "value_error", + ], ) def test_error_handling(exception, error_message, runner: CliRunner): # Patching the create_task method to raise different exceptions @@ -437,8 +465,33 @@ def test_error_handling(exception, error_message, runner: CliRunner): '{"time": 5}', ], ) - # error message is printed to stderr but test runner combines output - assert result.stdout == error_message + # error message is printed to stderr but test runner combines output + assert result.stdout == error_message + assert result.exit_code == 1 + + +@pytest.mark.parametrize( + "params, error", + [ + ("{", "Parameters are not valid JSON"), + ("[]", ""), + ], +) +def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner): + result = runner.invoke( + main, + [ + "-c", + "tests/unit_tests/example_yaml/valid_stomp_config.yaml", + "controller", + "run", + "sleep", + params, + ], + ) + # error message is printed to stderr but test runner combines output + assert result.stdout.startswith("Error: " + error) + assert result.exit_code == 1 def test_device_output_formatting():