Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ API

blueapi

.. autodata:: blueapi.service.main.REST_API_VERSION
.. autoattribute:: blueapi.config.ApplicationConfig.REST_API_VERSION
34 changes: 33 additions & 1 deletion src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from functools import cached_property
from pathlib import Path
from string import Template
from typing import Annotated, Any, Generic, Literal, TypeVar, cast
from typing import Annotated, Any, ClassVar, Generic, Literal, TypeVar, cast

import requests
import yaml
Expand Down Expand Up @@ -260,12 +260,44 @@ class NumtrackerConfig(BlueapiBaseModel):
detector_file_template: str = "{instrument}-{scan_id}-{device_name}"


class Tag(StrEnum):
TASK = "Task"
PLAN = "Plan"
DEVICE = "Device"
ENV = "Environment"
META = "Meta"


class ApplicationConfig(BlueapiBaseModel):
"""
Config for the worker application as a whole. Root of
config tree.
"""

#: API version to publish in OpenAPI schema
REST_API_VERSION: ClassVar[str] = "1.1.3"

LICENSE_INFO: ClassVar[dict[str, str]] = {
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
CONTEXT_HEADER: ClassVar[str] = "traceparent"
VENDOR_CONTEXT_HEADER: ClassVar[str] = "tracestate"
AUTHORIZAITON_HEADER: ClassVar[str] = "authorization"
PROPAGATED_HEADERS: ClassVar[set[str]] = {
CONTEXT_HEADER,
VENDOR_CONTEXT_HEADER,
AUTHORIZAITON_HEADER,
}
DOCS_ENDPOINT: ClassVar[str] = "/docs"
TAG_METADATA: ClassVar[list[dict[str, str]]] = [
{"name": Tag.TASK, "description": "Endpoints related to tasks"},
{"name": Tag.PLAN, "description": "Endpoints to get plans"},
{"name": Tag.DEVICE, "description": "Endpoints to get devices"},
{"name": Tag.ENV, "description": "Endpoints related to server environment"},
{"name": Tag.META, "description": "Endpoints used for auxiliary functions"},
]

stomp: StompConfig = Field(default_factory=StompConfig)
tiled: TiledConfig = Field(default_factory=TiledConfig)
env: EnvironmentConfig = Field(default_factory=EnvironmentConfig)
Expand Down
59 changes: 17 additions & 42 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import urllib.parse
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from enum import StrEnum
from typing import Annotated

import jwt
Expand Down Expand Up @@ -34,7 +33,7 @@
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError

from blueapi.config import ApplicationConfig, OIDCConfig
from blueapi.config import ApplicationConfig, OIDCConfig, Tag
from blueapi.service import interface
from blueapi.worker import TrackableTask, WorkerState
from blueapi.worker.event import TaskStatusEnum
Expand All @@ -57,38 +56,9 @@
)
from .runner import WorkerDispatcher

#: API version to publish in OpenAPI schema
REST_API_VERSION = "1.1.3"

LICENSE_INFO: dict[str, str] = {
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}
RUNNER: WorkerDispatcher | None = None

LOGGER = logging.getLogger(__name__)
CONTEXT_HEADER = "traceparent"
VENDOR_CONTEXT_HEADER = "tracestate"
AUTHORIZAITON_HEADER = "authorization"
PROPAGATED_HEADERS = {CONTEXT_HEADER, VENDOR_CONTEXT_HEADER, AUTHORIZAITON_HEADER}
DOCS_ENDPOINT = "/docs"


class Tag(StrEnum):
TASK = "Task"
PLAN = "Plan"
DEVICE = "Device"
ENV = "Environment"
META = "Meta"


TAG_METADATA: list[dict[str, str]] = [
{"name": Tag.TASK, "description": "Endpoints related to tasks"},
{"name": Tag.PLAN, "description": "Endpoints to get plans"},
{"name": Tag.DEVICE, "description": "Endpoints to get devices"},
{"name": Tag.ENV, "description": "Endpoints related to server environment"},
{"name": Tag.META, "description": "Endpoints used for auxiliary functions"},
]


def _runner() -> WorkerDispatcher:
Expand Down Expand Up @@ -133,14 +103,14 @@ async def inner(app: FastAPI):

def get_app(config: ApplicationConfig):
app = FastAPI(
docs_url=DOCS_ENDPOINT,
docs_url=ApplicationConfig.DOCS_ENDPOINT,
title="BlueAPI Control",
summary="BlueAPI wraps bluesky plans and devices and "
"exposes endpoints to send commands/receive data",
lifespan=lifespan(config),
version=REST_API_VERSION,
license_info=LICENSE_INFO,
openapi_tags=TAG_METADATA,
version=ApplicationConfig.REST_API_VERSION,
license_info=ApplicationConfig.LICENSE_INFO,
openapi_tags=ApplicationConfig.TAG_METADATA,
)
dependencies = []
if config.oidc:
Expand Down Expand Up @@ -210,7 +180,8 @@ async def on_token_error_401(_: Request, __: Exception):
def root_redirect() -> RedirectResponse:
"""Redirect to docs url"""
return RedirectResponse(
status_code=status.HTTP_307_TEMPORARY_REDIRECT, url=DOCS_ENDPOINT
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
url=ApplicationConfig.DOCS_ENDPOINT,
)


Expand Down Expand Up @@ -410,7 +381,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]:
return {
key: value
for key, value in request.headers.items()
if key.casefold() in PROPAGATED_HEADERS
if key.casefold() in ApplicationConfig.PROPAGATED_HEADERS
}


Expand Down Expand Up @@ -590,7 +561,7 @@ async def add_api_version_header(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
response = await call_next(request)
response.headers["X-API-Version"] = REST_API_VERSION
response.headers["X-API-Version"] = ApplicationConfig.REST_API_VERSION
return response


Expand All @@ -613,10 +584,14 @@ async def inject_propagated_observability_context(
HTTP headers and attach it to the local one.
"""
headers = request.headers
if CONTEXT_HEADER in headers:
carrier = {CONTEXT_HEADER: headers[CONTEXT_HEADER]}
if VENDOR_CONTEXT_HEADER in headers:
carrier[VENDOR_CONTEXT_HEADER] = headers[VENDOR_CONTEXT_HEADER]
if ApplicationConfig.CONTEXT_HEADER in headers:
carrier = {
ApplicationConfig.CONTEXT_HEADER: headers[ApplicationConfig.CONTEXT_HEADER]
}
if ApplicationConfig.VENDOR_CONTEXT_HEADER in headers:
carrier[ApplicationConfig.VENDOR_CONTEXT_HEADER] = headers[
ApplicationConfig.VENDOR_CONTEXT_HEADER
]
ctx = get_global_textmap().extract(carrier)

attach(ctx)
Expand Down
4 changes: 2 additions & 2 deletions src/blueapi/service/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pyparsing import Any

from blueapi.config import ApplicationConfig
from blueapi.service.main import LICENSE_INFO, get_app
from blueapi.service.main import get_app

DOCS_SCHEMA_LOCATION = Path(__file__).parents[3] / "docs" / "reference" / "openapi.yaml"

Expand All @@ -21,7 +21,7 @@ def generate_schema() -> Mapping[str, Any]:
openapi_version=app.openapi_version,
description=app.description,
routes=app.routes,
license_info=LICENSE_INFO,
license_info=ApplicationConfig.LICENSE_INFO,
)


Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def test_docs_redirect(
):
client_with_auth.follow_redirects = False
response = client_with_auth.get("/")
assert response.headers.get("location") == main.DOCS_ENDPOINT
assert response.headers.get("location") == ApplicationConfig.DOCS_ENDPOINT
assert response.status_code == status.HTTP_307_TEMPORARY_REDIRECT


Expand Down