diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 024cb6abc1..bc9f7f8001 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -9,6 +9,11 @@ Fixed
 ~~~~~
 
 * Fix redis SSL problems with sentinel #5660
 
+Added
+~~~~~
+
+* Added graceful shutdown for the workflow engine. #5463
+  Contributed by @khushboobhatia01
 
 3.7.0 - May 05, 2022
 --------------------
diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample
index bb6bbd8d42..d860951268 100644
--- a/conf/st2.conf.sample
+++ b/conf/st2.conf.sample
@@ -363,6 +363,8 @@ logging = /etc/st2/logging.timersengine.conf
 webui_base_url = https://localhost
 
 [workflow_engine]
+# How long to wait (in seconds) for the process to exit after receiving the shutdown signal.
+exit_still_active_check = 300
 # Max seconds to allow workflow execution be idled before it is identified as orphaned and cancelled by the garbage collector. A value of zero means the feature is disabled. This is disabled by default.
 gc_max_idle_sec = 0
 # Location of the logging configuration file.
@@ -373,4 +375,6 @@
 retry_max_jitter_msec = 1000
 retry_stop_max_msec = 60000
 # Interval inbetween retries.
 retry_wait_fixed_msec = 1000
+# Time interval (in seconds) between subsequent checks for executions still being handled by the workflow engine.
+still_active_check_interval = 2
diff --git a/st2actions/st2actions/cmd/workflow_engine.py b/st2actions/st2actions/cmd/workflow_engine.py
index e6eb65d5a8..ac6626afde 100644
--- a/st2actions/st2actions/cmd/workflow_engine.py
+++ b/st2actions/st2actions/cmd/workflow_engine.py
@@ -37,7 +37,6 @@
 __all__ = ["main"]
 
 LOG = logging.getLogger(__name__)
-WORKFLOW_ENGINE = "workflow_engine"
 
 
 def setup_sigterm_handler(engine):
@@ -53,7 +52,7 @@ def sigterm_handler(signum=None, frame=None):
 def setup():
     capabilities = {"name": "workflowengine", "type": "passive"}
     common_setup(
-        service=WORKFLOW_ENGINE,
+        service=workflows.WORKFLOW_ENGINE,
         config=config,
         setup_db=True,
         register_mq_exchanges=True,
@@ -72,7 +71,7 @@ def run_server():
         engine.start(wait=True)
     except (KeyboardInterrupt, SystemExit):
         LOG.info("(PID=%s) Workflow engine stopped.", os.getpid())
-        deregister_service(service=WORKFLOW_ENGINE)
+        deregister_service(service=workflows.WORKFLOW_ENGINE)
         engine.shutdown()
     except:
         LOG.exception("(PID=%s) Workflow engine unexpectedly stopped.", os.getpid())
diff --git a/st2actions/st2actions/workflows/workflows.py b/st2actions/st2actions/workflows/workflows.py
index 2151c7d440..700244d058 100644
--- a/st2actions/st2actions/workflows/workflows.py
+++ b/st2actions/st2actions/workflows/workflows.py
@@ -14,9 +14,13 @@
 # limitations under the License.
 
 from __future__ import absolute_import
 
+from oslo_config import cfg
 from orquesta import statuses
-
+from tooz.coordination import GroupNotCreated
+from st2common.services import coordination
+from eventlet.semaphore import Semaphore
+from eventlet import spawn_after
 from st2common.constants import action as ac_const
 from st2common import log as logging
 from st2common.metrics import base as metrics
@@ -24,12 +28,15 @@
 from st2common.models.db import workflow as wf_db_models
 from st2common.persistence import liveaction as lv_db_access
 from st2common.persistence import workflow as wf_db_access
+from st2common.persistence import execution as ex_db_access
+from st2common.services import action as ac_svc
 from st2common.services import policies as pc_svc
 from st2common.services import workflows as wf_svc
 from st2common.transport import consumers
 from st2common.transport import queues
 from st2common.transport import utils as txpt_utils
-
+from st2common.util import concurrency
+from st2common.util import action_db as action_utils
 
 LOG = logging.getLogger(__name__)
 
@@ -40,10 +47,17 @@
     queues.WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE,
 ]
 
+WORKFLOW_ENGINE = "workflow_engine"
+WORKFLOW_ENGINE_START_STOP_SEQ = "workflow_engine_start_stop_seq"
+
 
 class WorkflowExecutionHandler(consumers.VariableMessageHandler):
     def __init__(self, connection, queues):
         super(WorkflowExecutionHandler, self).__init__(connection, queues)
+        self._active_messages = 0
+        self._semaphore = Semaphore()
+        # Delay resuming so that workflows left in the pausing state during a shutdown have time to transition to paused after the engine starts up.
+        self._delay = 30
 
         def handle_workflow_execution_with_instrumentation(wf_ex_db):
             with metrics.CounterWithTimer(key="orquesta.workflow.executions"):
@@ -78,6 +92,8 @@ def process(self, message):
             raise ValueError(msg)
 
         try:
+            with self._semaphore:
+                self._active_messages += 1
             handler_function(message)
         except Exception as e:
             # If the exception is caused by DB connection error, then the following
@@ -85,6 +101,60 @@
             # the database and fail the workflow execution gracefully. In this case,
             # the garbage collector will find and cancel these workflow executions.
             self.fail_workflow_execution(message, e)
+        finally:
+            with self._semaphore:
+                self._active_messages -= 1
+
+    def start(self, wait):
+        spawn_after(self._delay, self._resume_workflows_paused_during_shutdown)
+        super(WorkflowExecutionHandler, self).start(wait=wait)
+
+    def shutdown(self):
+        super(WorkflowExecutionHandler, self).shutdown()
+        exit_timeout = cfg.CONF.workflow_engine.exit_still_active_check
+        sleep_delay = cfg.CONF.workflow_engine.still_active_check_interval
+        timeout = 0
+
+        while timeout < exit_timeout and self._active_messages > 0:
+            concurrency.sleep(sleep_delay)
+            timeout += sleep_delay
+
+        coordinator = coordination.get_coordinator()
+        member_ids = []
+        with coordinator.get_lock(WORKFLOW_ENGINE_START_STOP_SEQ):
+            try:
+                group_id = coordination.get_group_id(WORKFLOW_ENGINE)
+                member_ids = list(coordinator.get_members(group_id).get())
+            except GroupNotCreated:
+                pass
+
+            # Pause running workflows only if no other WFE remains in the service registry.
+            if cfg.CONF.coordination.service_registry and not member_ids:
+                ac_ex_dbs = self._get_running_workflows()
+                for ac_ex_db in ac_ex_dbs:
+                    lv_ac = action_utils.get_liveaction_by_id(ac_ex_db.liveaction["id"])
+                    ac_svc.request_pause(lv_ac, WORKFLOW_ENGINE_START_STOP_SEQ)
+
+    def _get_running_workflows(self):
+        query_filters = {
+            "runner__name": "orquesta",
+            "status": ac_const.LIVEACTION_STATUS_RUNNING,
+        }
+        return ex_db_access.ActionExecution.query(**query_filters)
+
+    def _get_workflows_paused_during_shutdown(self):
+        query_filters = {
+            "status": ac_const.LIVEACTION_STATUS_PAUSED,
+            "context__paused_by": WORKFLOW_ENGINE_START_STOP_SEQ,
+        }
+        return lv_db_access.LiveAction.query(**query_filters)
+
+    def _resume_workflows_paused_during_shutdown(self):
+        coordinator = coordination.get_coordinator()
+        with coordinator.get_lock(WORKFLOW_ENGINE_START_STOP_SEQ):
+            lv_ac_dbs = self._get_workflows_paused_during_shutdown()
+            for lv_ac_db in lv_ac_dbs:
+                ac_svc.request_resume(lv_ac_db, WORKFLOW_ENGINE_START_STOP_SEQ)
 
     def fail_workflow_execution(self, message, exception):
         # Prepare attributes based on message type.
diff --git a/st2actions/tests/unit/test_workflow_engine.py b/st2actions/tests/unit/test_workflow_engine.py
index 7c572e7ebb..a2090a2530 100644
--- a/st2actions/tests/unit/test_workflow_engine.py
+++ b/st2actions/tests/unit/test_workflow_engine.py
@@ -271,3 +271,237 @@ def test_process_error_handling_has_error(self, mock_get_lock):
         # Assert workflow execution is cleaned up and canceled.
         lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
         self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_CANCELED)
+
+    @mock.patch.object(
+        coordination_service.NoOpDriver,
+        "get_members",
+        mock.MagicMock(return_value=coordination_service.NoOpAsyncResult("")),
+    )
+    def test_workflow_engine_shutdown(self):
+        cfg.CONF.set_override(
+            name="service_registry", override=True, group="coordination"
+        )
+        wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+        lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
+        lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
+
+        # Assert action execution is running.
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        wf_ex_db = wf_db_access.WorkflowExecution.query(
+            action_execution=str(ac_ex_db.id)
+        )[0]
+        self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        workflow_engine = workflows.get_engine()
+
+        eventlet.spawn(workflow_engine.shutdown)
+
+        # Sleep for a few seconds to ensure the execution transitions to pausing.
+        eventlet.sleep(5)
+
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_PAUSING)
+
+        # Process task1.
+        query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
+        t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
+        t1_ac_ex_db = ex_db_access.ActionExecution.query(
+            task_execution=str(t1_ex_db.id)
+        )[0]
+
+        workflows.get_engine().process(t1_ac_ex_db)
+        t1_ac_ex_db = ex_db_access.ActionExecution.query(
+            task_execution=str(t1_ex_db.id)
+        )[0]
+        self.assertEqual(
+            t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+        )
+
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_PAUSED)
+
+        workflow_engine = workflows.get_engine()
+        workflow_engine._delay = 0
+        workflow_engine.start(False)
+        eventlet.sleep(workflow_engine._delay + 5)
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertTrue(
+            lv_ac_db.status
+            in [
+                action_constants.LIVEACTION_STATUS_RESUMING,
+                action_constants.LIVEACTION_STATUS_RUNNING,
+                action_constants.LIVEACTION_STATUS_SUCCEEDED,
+            ]
+        )
+
+    @mock.patch.object(
+        coordination_service.NoOpDriver,
+        "get_members",
+        mock.MagicMock(return_value=coordination_service.NoOpAsyncResult("member-1")),
+    )
+    def test_workflow_engine_shutdown_with_multiple_members(self):
+        cfg.CONF.set_override(
+            name="service_registry", override=True, group="coordination"
+        )
+        wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+        lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
+        lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
+
+        # Assert action execution is running.
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        wf_ex_db = wf_db_access.WorkflowExecution.query(
+            action_execution=str(ac_ex_db.id)
+        )[0]
+        self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        workflow_engine = workflows.get_engine()
+
+        eventlet.spawn(workflow_engine.shutdown)
+
+        # Sleep for a few seconds to ensure the shutdown sequence completes.
+        eventlet.sleep(5)
+
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+
+        # Process task1.
+        query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
+        t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
+        t1_ac_ex_db = ex_db_access.ActionExecution.query(
+            task_execution=str(t1_ex_db.id)
+        )[0]
+
+        workflows.get_engine().process(t1_ac_ex_db)
+        t1_ac_ex_db = ex_db_access.ActionExecution.query(
+            task_execution=str(t1_ex_db.id)
+        )[0]
+        self.assertEqual(
+            t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+        )
+
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+
+    def test_workflow_engine_shutdown_with_service_registry_disabled(self):
+        cfg.CONF.set_override(
+            name="service_registry", override=False, group="coordination"
+        )
+        wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+        lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
+        lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
+
+        # Assert action execution is running.
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        wf_ex_db = wf_db_access.WorkflowExecution.query(
+            action_execution=str(ac_ex_db.id)
+        )[0]
+        self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        workflow_engine = workflows.get_engine()
+
+        eventlet.spawn(workflow_engine.shutdown)
+
+        # Sleep for a few seconds to ensure the shutdown sequence completes.
+        eventlet.sleep(5)
+
+        # The WFE doesn't pause the workflow, since the service registry is disabled.
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+
+    @mock.patch.object(
+        coordination_service.NoOpDriver,
+        "get_lock",
+        mock.MagicMock(return_value=coordination_service.NoOpLock(name="noop")),
+    )
+    def test_workflow_engine_shutdown_first_then_start(self):
+        cfg.CONF.set_override(
+            name="service_registry", override=True, group="coordination"
+        )
+        cfg.CONF.set_override(
+            name="exit_still_active_check", override=0, group="workflow_engine"
+        )
+        wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+        lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
+        lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
+
+        # Assert action execution is running.
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        wf_ex_db = wf_db_access.WorkflowExecution.query(
+            action_execution=str(ac_ex_db.id)
+        )[0]
+        self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        workflow_engine = workflows.get_engine()
+
+        workflow_engine._delay = 5
+        # Initiate shutdown first.
+        eventlet.spawn(workflow_engine.shutdown)
+        eventlet.spawn_after(1, workflow_engine.start, True)
+
+        # Sleep for a few seconds to ensure the shutdown sequence completes.
+        eventlet.sleep(2)
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+
+        # The shutdown routine acquires the lock first.
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_PAUSING)
+        # Process task1.
+        query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
+        t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
+        t1_ac_ex_db = ex_db_access.ActionExecution.query(
+            task_execution=str(t1_ex_db.id)
+        )[0]
+
+        workflows.get_engine().process(t1_ac_ex_db)
+        # Startup sequence won't proceed until the shutdown routine completes.
+        # Assuming the shutdown sequence is complete, the startup sequence will resume the workflow.
+        eventlet.sleep(workflow_engine._delay + 5)
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertTrue(
+            lv_ac_db.status
+            in [
+                action_constants.LIVEACTION_STATUS_RESUMING,
+                action_constants.LIVEACTION_STATUS_RUNNING,
+                action_constants.LIVEACTION_STATUS_SUCCEEDED,
+            ]
+        )
+
+    @mock.patch.object(
+        coordination_service.NoOpDriver,
+        "get_lock",
+        mock.MagicMock(return_value=coordination_service.NoOpLock(name="noop")),
+    )
+    def test_workflow_engine_start_first_then_shutdown(self):
+        cfg.CONF.set_override(
+            name="service_registry", override=True, group="coordination"
+        )
+        cfg.CONF.set_override(
+            name="exit_still_active_check", override=0, group="workflow_engine"
+        )
+        wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+        lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
+        lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
+
+        # Assert action execution is running.
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        wf_ex_db = wf_db_access.WorkflowExecution.query(
+            action_execution=str(ac_ex_db.id)
+        )[0]
+        self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+        workflow_engine = workflows.get_engine()
+
+        workflow_engine._delay = 0
+        # Initiate start first.
+        eventlet.spawn(workflow_engine.start, True)
+        eventlet.spawn_after(1, workflow_engine.shutdown)
+
+        coordination_service.NoOpDriver.get_members = mock.MagicMock(
+            return_value=coordination_service.NoOpAsyncResult("member-1")
+        )
+
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+
+        # The startup routine acquires the lock first and the shutdown routine sees a new member present in the registry.
+        eventlet.sleep(workflow_engine._delay + 5)
+        lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
+        self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
diff --git a/st2common/st2common/config.py b/st2common/st2common/config.py
index 59cb9d01ab..c88955e4bb 100644
--- a/st2common/st2common/config.py
+++ b/st2common/st2common/config.py
@@ -797,6 +797,16 @@ def register_opts(ignore_errors=False):
             "orphaned and cancelled by the garbage collector. A value of zero means the "
             "feature is disabled. This is disabled by default.",
         ),
+        cfg.IntOpt(
+            "exit_still_active_check",
+            default=300,
+            help="How long to wait (in seconds) for the process to exit after receiving the shutdown signal.",
+        ),
+        cfg.IntOpt(
+            "still_active_check_interval",
+            default=2,
+            help="Time interval (in seconds) between subsequent checks for executions still being handled by the workflow engine.",
+        ),
     ]
 
     do_register_opts(
diff --git a/st2common/st2common/services/coordination.py b/st2common/st2common/services/coordination.py
index 15045d0eff..fc26e42e48 100644
--- a/st2common/st2common/services/coordination.py
+++ b/st2common/st2common/services/coordination.py
@@ -277,3 +277,11 @@ def get_member_id():
     proc_info = system_info.get_process_info()
     member_id = six.b("%s_%d" % (proc_info["hostname"], proc_info["pid"]))
     return member_id
+
+
+def get_group_id(service):
+    if not isinstance(service, six.binary_type):
+        group_id = service.encode("utf-8")
+    else:
+        group_id = service
+    return group_id
diff --git a/st2common/st2common/transport/consumers.py b/st2common/st2common/transport/consumers.py
index 44d867962d..6f4cca7c87 100644
--- a/st2common/st2common/transport/consumers.py
+++ b/st2common/st2common/transport/consumers.py
@@ -43,6 +43,7 @@ def __init__(self, connection, queues, handler):
         self._handler = handler
 
     def shutdown(self):
+        self.should_stop = True
         self._dispatcher.shutdown()
 
     def get_consumers(self, Consumer, channel):
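Reviewer note (not part of the patch): the core of the new WorkflowExecutionHandler.shutdown() is a bounded drain loop, controlled by exit_still_active_check and still_active_check_interval, followed by pausing still-running workflows only when this is the last engine instance in the service registry; on the next startup, _resume_workflows_paused_during_shutdown() resumes anything whose context records WORKFLOW_ENGINE_START_STOP_SEQ as the pause requester. The standalone sketch below illustrates that drain-then-pause pattern using only the standard library; the GracefulHandler class and the is_last_instance / pause_running_workflows callables are hypothetical stand-ins for the st2 registry and pause/resume services, not code from this patch.

    import threading
    import time


    class GracefulHandler(object):
        """Simplified illustration of the drain-then-pause pattern in shutdown()."""

        def __init__(self, exit_still_active_check=300, still_active_check_interval=2):
            self._active_messages = 0
            self._lock = threading.Lock()
            self._exit_timeout = exit_still_active_check
            self._sleep_delay = still_active_check_interval

        def process(self, message, handler):
            # Track in-flight messages so shutdown knows when it is safe to proceed.
            with self._lock:
                self._active_messages += 1
            try:
                handler(message)
            finally:
                with self._lock:
                    self._active_messages -= 1

        def shutdown(self, is_last_instance, pause_running_workflows):
            # Wait, bounded by exit_still_active_check, for in-flight messages to drain.
            waited = 0
            while waited < self._exit_timeout and self._active_messages > 0:
                time.sleep(self._sleep_delay)
                waited += self._sleep_delay
            # Only the last remaining instance pauses running workflows;
            # peer instances keep handling them.
            if is_last_instance():
                pause_running_workflows()


    # Example usage with trivial stand-ins for the registry and pause service:
    handler = GracefulHandler(exit_still_active_check=10, still_active_check_interval=1)
    handler.process("msg", handler=lambda m: None)
    handler.shutdown(is_last_instance=lambda: True, pause_running_workflows=lambda: None)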