diff --git a/castle/client.py b/castle/client.py index 2239bed..7855638 100644 --- a/castle/client.py +++ b/castle/client.py @@ -9,6 +9,7 @@ from castle.commands.track import CommandsTrack from castle.errors import InternalServerError, RequestError, ImpersonationFailed from castle.failover.prepare_response import FailoverPrepareResponse +from castle.failover.strategy import FailoverStrategy from castle.utils.timestamp import UtilsTimestamp as generate_timestamp @@ -43,7 +44,7 @@ def to_options(options=None): @staticmethod def failover_response_or_raise(options, exception): - if configuration.failover_strategy == 'throw': + if configuration.failover_strategy == FailoverStrategy.THROW.value: raise exception return FailoverPrepareResponse( options.get('user_id'), None, exception.__class__.__name__ diff --git a/castle/configuration.py b/castle/configuration.py index 8bcd1f2..5cfb4e9 100644 --- a/castle/configuration.py +++ b/castle/configuration.py @@ -1,6 +1,7 @@ from urllib.parse import urlparse, ParseResult from castle.errors import ConfigurationError from castle.headers.format import HeadersFormat +from castle.failover.strategy import FailoverStrategy DEFAULT_ALLOWLIST = [ "Accept", @@ -29,11 +30,11 @@ # API endpoint BASE_URL = 'https://api.castle.io/v1' -FAILOVER_STRATEGY = 'allow' +FAILOVER_STRATEGY = FailoverStrategy.ALLOW.value # 1000 milliseconds REQUEST_TIMEOUT = 1000 +FAILOVER_STRATEGIES = FailoverStrategy.list() # regexp of trusted proxies which is always appended to the trusted proxy list -FAILOVER_STRATEGIES = ['allow', 'deny', 'challenge', 'throw'] TRUSTED_PROXIES = [r""" \A127\.0\.0\.1\Z| \A(10|172\.(1[6-9]|2[0-9]|30|31)|192\.168)\.| diff --git a/castle/failover/strategy.py b/castle/failover/strategy.py new file mode 100644 index 0000000..15a3762 --- /dev/null +++ b/castle/failover/strategy.py @@ -0,0 +1,17 @@ +import enum + + +# handles failover strategy consts +class FailoverStrategy(enum.Enum): + # allow + ALLOW = 'allow' + # challenge + CHALLENGE = 'challenge' + # deny + DENY = 'deny' + # throw an error + THROW = 'throw' + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) diff --git a/castle/test/__init__.py b/castle/test/__init__.py index 87d3578..5f600f9 100644 --- a/castle/test/__init__.py +++ b/castle/test/__init__.py @@ -22,6 +22,7 @@ 'castle.test.core.process_response_test', 'castle.test.core.send_request_test', 'castle.test.failover.prepare_response_test', + 'castle.test.failover.strategy_test', 'castle.test.headers.extract_test', 'castle.test.headers.filter_test', 'castle.test.headers.format_test', @@ -34,6 +35,7 @@ 'castle.test.utils.timestamp_test', 'castle.test.validators.not_supported_test', 'castle.test.validators.present_test', + 'castle.test.verdict_test', ] # pylint: disable=redefined-builtin diff --git a/castle/test/client_test.py b/castle/test/client_test.py index a12a7d8..1038199 100644 --- a/castle/test/client_test.py +++ b/castle/test/client_test.py @@ -1,11 +1,13 @@ import json from collections import namedtuple import responses -from castle.test import mock, unittest +from castle.api_request import APIRequest from castle.client import Client from castle.configuration import configuration from castle.errors import ImpersonationFailed -from castle.api_request import APIRequest +from castle.failover.strategy import FailoverStrategy +from castle.test import mock, unittest +from castle.verdict import Verdict from castle.version import VERSION @@ -96,7 +98,7 @@ def test_identify_tracked_false(self): @responses.activate def test_authenticate_tracked_true(self): - response_text = {'action': 'allow', 'user_id': '1234'} + response_text = {'action': Verdict.ALLOW.value, 'user_id': '1234'} responses.add( responses.POST, 'https://api.castle.io/v1/authenticate', @@ -111,7 +113,7 @@ def test_authenticate_tracked_true(self): @responses.activate def test_authenticate_tracked_true_status_500(self): response_text = { - 'action': 'allow', + 'action': Verdict.ALLOW.value, 'user_id': '1234', 'failover': True, 'failover_reason': 'InternalServerError' @@ -128,7 +130,7 @@ def test_authenticate_tracked_true_status_500(self): def test_authenticate_tracked_false(self): response_text = { - 'action': 'allow', + 'action': Verdict.ALLOW.value, 'user_id': '1234', 'failover': True, 'failover_reason': 'Castle set to do not track.' @@ -213,7 +215,7 @@ def test_failover_strategy_not_throw(self): self.assertEqual( Client.failover_response_or_raise(options, Exception()), { - 'action': 'allow', + 'action': Verdict.ALLOW.value, 'user_id': '1234', 'failover': True, 'failover_reason': 'Exception' @@ -222,14 +224,14 @@ def test_failover_strategy_not_throw(self): def test_failover_strategy_throw(self): options = {'user_id': '1234'} - configuration.failover_strategy = 'throw' + configuration.failover_strategy = FailoverStrategy.THROW.value with self.assertRaises(Exception): Client.failover_response_or_raise(options, Exception()) - configuration.failover_strategy = 'allow' + configuration.failover_strategy = FailoverStrategy.ALLOW.value @responses.activate def test_timestamps_are_not_global(self): - response_text = {'action': 'allow', 'user_id': '1234'} + response_text = {'action': Verdict.ALLOW.value, 'user_id': '1234'} responses.add( responses.POST, 'https://api.castle.io/v1/authenticate', diff --git a/castle/test/configuration_test.py b/castle/test/configuration_test.py index 32f8415..1c2e396 100644 --- a/castle/test/configuration_test.py +++ b/castle/test/configuration_test.py @@ -2,6 +2,7 @@ from castle.test import unittest from castle.errors import ConfigurationError from castle.configuration import Configuration +from castle.failover.strategy import FailoverStrategy class ConfigurationTestCase(unittest.TestCase): @@ -14,7 +15,7 @@ def test_default_values(self): self.assertEqual(config.allowlisted, []) self.assertEqual(config.denylisted, []) self.assertEqual(config.request_timeout, 1000) - self.assertEqual(config.failover_strategy, 'allow') + self.assertEqual(config.failover_strategy, FailoverStrategy.ALLOW.value) self.assertEqual(config.ip_headers, []) self.assertEqual(config.trusted_proxies, []) @@ -74,8 +75,8 @@ def test_request_timeout_setter(self): def test_failover_strategy_setter_valid(self): config = Configuration() - config.failover_strategy = 'throw' - self.assertEqual(config.failover_strategy, 'throw') + config.failover_strategy = FailoverStrategy.THROW.value + self.assertEqual(config.failover_strategy, FailoverStrategy.THROW.value) def test_failover_strategy_setter_invalid(self): config = Configuration() diff --git a/castle/test/failover/strategy_test.py b/castle/test/failover/strategy_test.py new file mode 100644 index 0000000..8e82a08 --- /dev/null +++ b/castle/test/failover/strategy_test.py @@ -0,0 +1,19 @@ +from castle.test import unittest +from castle.failover.strategy import FailoverStrategy + + +class FailoverStrategyTestCase(unittest.TestCase): + def test_allow(self): + self.assertEqual(FailoverStrategy.ALLOW.value, 'allow') + + def test_challenge(self): + self.assertEqual(FailoverStrategy.CHALLENGE.value, 'challenge') + + def test_deny(self): + self.assertEqual(FailoverStrategy.DENY.value, 'deny') + + def test_throw(self): + self.assertEqual(FailoverStrategy.THROW.value, 'throw') + + def test_list(self): + self.assertEqual(FailoverStrategy.list(), ['allow', 'challenge', 'deny', 'throw']) diff --git a/castle/test/verdict_test.py b/castle/test/verdict_test.py new file mode 100644 index 0000000..3a0da8a --- /dev/null +++ b/castle/test/verdict_test.py @@ -0,0 +1,13 @@ +from castle.test import unittest +from castle.verdict import Verdict + + +class VerdictTestCase(unittest.TestCase): + def test_allow(self): + self.assertEqual(Verdict.ALLOW.value, 'allow') + + def test_challenge(self): + self.assertEqual(Verdict.CHALLENGE.value, 'challenge') + + def test_deny(self): + self.assertEqual(Verdict.DENY.value, 'deny') diff --git a/castle/verdict.py b/castle/verdict.py new file mode 100644 index 0000000..ffe3322 --- /dev/null +++ b/castle/verdict.py @@ -0,0 +1,11 @@ +import enum + + +# handles verdict consts +class Verdict(enum.Enum): + # allow + ALLOW = 'allow' + # challenge + CHALLENGE = 'challenge' + # deny + DENY = 'deny'