diff --git a/setup.cfg b/setup.cfg index 164be372..e04ca80b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,6 @@ exclude=tests/* test=pytest [tool:pytest] -ignore_glob=./splitio/_OLD/* addopts = --verbose --cov=splitio --cov-report xml python_classes=*Tests diff --git a/splitio/optional/loaders.py b/splitio/optional/loaders.py index b3c73d00..46c017b7 100644 --- a/splitio/optional/loaders.py +++ b/splitio/optional/loaders.py @@ -1,3 +1,4 @@ +import sys try: import asyncio import aiohttp @@ -10,3 +11,9 @@ def missing_asyncio_dependencies(*_, **__): ) aiohttp = missing_asyncio_dependencies asyncio = missing_asyncio_dependencies + +async def _anext(it): + return await it.__anext__() + +if sys.version_info.major < 3 or sys.version_info.minor < 10: + anext = _anext diff --git a/splitio/push/manager.py b/splitio/push/manager.py index 4f5112ae..05306441 100644 --- a/splitio/push/manager.py +++ b/splitio/push/manager.py @@ -290,6 +290,7 @@ def _handle_connection_end(self): if feedback is not None: self._feedback_loop.put(feedback) + class PushManagerAsync(PushManagerBase): # pylint:disable=too-many-instance-attributes """Push notifications susbsytem manager.""" @@ -354,10 +355,7 @@ async def start(self): return try: - self._token = await self._get_auth_token() - await self._trigger_connection_flow() - self._running_task = asyncio.get_running_loop().create_task(self._read_and_handle_events()) - self._token_task = asyncio.get_running_loop().create_task(self._token_refresh()) + self._running_task = asyncio.get_running_loop().create_task(self._trigger_connection_flow()) except Exception as e: _LOGGER.error("Exception renewing token authentication") _LOGGER.debug(str(e)) @@ -404,23 +402,21 @@ async def _event_handler(self, event): parsed.event_type) _LOGGER.debug(str(parsed), exc_info=True) - async def _token_refresh(self): + async def _token_refresh(self, current_token): """Refresh auth token.""" while self._running: try: - await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.TOKEN_REFRESH, 1000 * self._token.exp, get_current_epoch_time_ms())) - await asyncio.sleep(self._get_time_period(self._token)) - _LOGGER.info("retriggering authentication flow.") + await asyncio.sleep(self._get_time_period(current_token)) + + # track proper metrics & stop everything await self._processor.update_workers_status(False) self._status_tracker.notify_sse_shutdown_expected() await self._sse_client.stop() self._running_task.cancel() self._running = False - self._token = await self._get_auth_token() - await self._telemetry_runtime_producer.record_token_refreshes() - await self._trigger_connection_flow() - self._running_task = asyncio.get_running_loop().create_task(self._read_and_handle_events()) + _LOGGER.info("retriggering authentication flow.") + self._running_task = asyncio.get_running_loop().create_task(self._trigger_connection_flow()) except Exception as e: _LOGGER.error("Exception renewing token authentication") _LOGGER.debug(str(e)) @@ -430,6 +426,9 @@ async def _get_auth_token(self): """Get new auth token""" try: token = await self._auth_api.authenticate() + await self._telemetry_runtime_producer.record_token_refreshes() + await self._telemetry_runtime_producer.record_streaming_event(StreamingEventTypes.TOKEN_REFRESH, 1000 * token.exp, get_current_epoch_time_ms()) + except APIException: _LOGGER.error('error performing sse auth request.') _LOGGER.debug('stack trace: ', exc_info=True) @@ -447,18 +446,21 @@ async def _trigger_connection_flow(self): """Authenticate and start a connection.""" self._status_tracker.reset() self._running = True - # awaiting first successful event - self._events_task = self._sse_client.start(self._token) - async def _read_and_handle_events(self): - first_event = await anext(self._events_task) + token = await self._get_auth_token() + events_source = self._sse_client.start(token) + first_event = await _anext(events_source) if first_event.event == SSE_EVENT_ERROR: self._running = False raise(Exception("could not start SSE session")) _LOGGER.debug("connected to streaming, scheduling next refresh") + self._token_task = asyncio.get_running_loop().create_task(self._token_refresh(token)) await self._handle_connection_ready() await self._telemetry_runtime_producer.record_streaming_event((StreamingEventTypes.CONNECTION_ESTABLISHED, 0, get_current_epoch_time_ms())) + await self._consume_events(events_source) + + async def _consume_events(self, events_task): try: while self._running: event = await anext(self._events_task) @@ -541,4 +543,4 @@ async def _handle_connection_end(self): """ feedback = self._status_tracker.handle_disconnect() if feedback is not None: - await self._feedback_loop.put(feedback) \ No newline at end of file + await self._feedback_loop.put(feedback) diff --git a/splitio/push/processor.py b/splitio/push/processor.py index c530c575..75216130 100644 --- a/splitio/push/processor.py +++ b/splitio/push/processor.py @@ -1,13 +1,28 @@ """Message processor & Notification manager keeper implementations.""" from queue import Queue +import abc from splitio.push.parser import UpdateType -from splitio.push.workers import SplitWorker -from splitio.push.workers import SegmentWorker +from splitio.push.workers import SplitWorker, SplitWorkerAsync, SegmentWorker, SegmentWorkerAsync +from splitio.optional.loaders import asyncio +class MessageProcessorBase(object, metaclass=abc.ABCMeta): + """Message processor template.""" -class MessageProcessor(object): + @abc.abstractmethod + def update_workers_status(self, enabled): + """Enable/Disable push update workers.""" + + @abc.abstractmethod + def handle(self, event): + """Handle incoming update event.""" + + @abc.abstractmethod + def shutdown(self): + """Stop splits & segments workers.""" + +class MessageProcessor(MessageProcessorBase): """Message processor class.""" def __init__(self, synchronizer): @@ -89,3 +104,87 @@ def shutdown(self): """Stop splits & segments workers.""" self._split_worker.stop() self._segments_worker.stop() + + +class MessageProcessorAsync(MessageProcessorBase): + """Message processor class.""" + + def __init__(self, synchronizer): + """ + Class constructor. + + :param synchronizer: synchronizer component + :type synchronizer: splitio.sync.synchronizer.Synchronizer + """ + self._split_queue = asyncio.Queue() + self._segments_queue = asyncio.Queue() + self._synchronizer = synchronizer + self._split_worker = SplitWorkerAsync(synchronizer.synchronize_splits, self._split_queue) + self._segments_worker = SegmentWorkerAsync(synchronizer.synchronize_segment, self._segments_queue) + self._handlers = { + UpdateType.SPLIT_UPDATE: self._handle_split_update, + UpdateType.SPLIT_KILL: self._handle_split_kill, + UpdateType.SEGMENT_UPDATE: self._handle_segment_change + } + + async def _handle_split_update(self, event): + """ + Handle incoming split update notification. + + :param event: Incoming split change event + :type event: splitio.push.parser.SplitChangeUpdate + """ + await self._split_queue.put(event) + + async def _handle_split_kill(self, event): + """ + Handle incoming split kill notification. + + :param event: Incoming split kill event + :type event: splitio.push.parser.SplitKillUpdate + """ + await self._synchronizer.kill_split(event.split_name, event.default_treatment, + event.change_number) + await self._split_queue.put(event) + + async def _handle_segment_change(self, event): + """ + Handle incoming segment update notification. + + :param event: Incoming segment change event + :type event: splitio.push.parser.Update + """ + await self._segments_queue.put(event) + + async def update_workers_status(self, enabled): + """ + Enable/Disable push update workers. + + :param enabled: if True, enable workers. If False, disable them. + :type enabled: bool + """ + if enabled: + self._split_worker.start() + self._segments_worker.start() + else: + await self._split_worker.stop() + await self._segments_worker.stop() + + async def handle(self, event): + """ + Handle incoming update event. + + :param event: incoming data update event. + :type event: splitio.push.BaseUpdate + """ + try: + handle = self._handlers[event.update_type] + except KeyError as exc: + raise Exception('no handler for notification type: %s' % event.update_type) from exc + + await handle(event) + + async def shutdown(self): + """Stop splits & segments workers.""" + await self._split_worker.stop() + await self._segments_worker.stop() diff --git a/splitio/push/splitsse.py b/splitio/push/splitsse.py index 0d416288..0adc86ef 100644 --- a/splitio/push/splitsse.py +++ b/splitio/push/splitsse.py @@ -2,16 +2,18 @@ import logging import threading from enum import Enum -from splitio.push.sse import SSEClient, SSE_EVENT_ERROR +import abc +import sys + +from splitio.push.sse import SSEClient, SSEClientAsync, SSE_EVENT_ERROR from splitio.util.threadutil import EventGroup from splitio.api import headers_from_metadata - +from splitio.optional.loaders import anext _LOGGER = logging.getLogger(__name__) - -class SplitSSEClient(object): # pylint: disable=too-many-instance-attributes - """Split streaming endpoint SSE client.""" +class SplitSSEClientBase(object, metaclass=abc.ABCMeta): + """Split streaming endpoint SSE base client.""" KEEPALIVE_TIMEOUT = 70 @@ -21,6 +23,50 @@ class _Status(Enum): ERRORED = 2 CONNECTED = 3 + @staticmethod + def _format_channels(channels): + """ + Format channels into a list from the raw object retrieved in the token. + + :param channels: object as extracted from the JWT capabilities. + :type channels: dict[str,list[str]] + + :returns: channels as a list of strings. + :rtype: list[str] + """ + regular = [k for (k, v) in channels.items() if v == ['subscribe']] + occupancy = ['[?occupancy=metrics.publishers]' + k + for (k, v) in channels.items() + if 'channel-metadata:publishers' in v] + return regular + occupancy + + def _build_url(self, token): + """ + Build the url to connect to and return it as a string. + + :param token: (parsed) JWT + :type token: splitio.models.token.Token + + :returns: true if the connection was successful. False otherwise. + :rtype: bool + """ + return '{base}/event-stream?v=1.1&accessToken={token}&channels={channels}'.format( + base=self._base_url, + token=token.token, + channels=','.join(self._format_channels(token.channels))) + + @abc.abstractmethod + def start(self, token): + """Open a connection to start listening for events.""" + + @abc.abstractmethod + def stop(self, blocking=False, timeout=None): + """Abort the ongoing connection.""" + + +class SplitSSEClient(SplitSSEClientBase): # pylint: disable=too-many-instance-attributes + """Split streaming endpoint SSE client.""" + def __init__(self, event_callback, sdk_metadata, first_event_callback=None, connection_closed_callback=None, client_key=None, base_url='https://streaming.split.io'): @@ -72,38 +118,6 @@ def _raw_event_handler(self, event): if event.data is not None: self._callback(event) - @staticmethod - def _format_channels(channels): - """ - Format channels into a list from the raw object retrieved in the token. - - :param channels: object as extracted from the JWT capabilities. - :type channels: dict[str,list[str]] - - :returns: channels as a list of strings. - :rtype: list[str] - """ - regular = [k for (k, v) in channels.items() if v == ['subscribe']] - occupancy = ['[?occupancy=metrics.publishers]' + k - for (k, v) in channels.items() - if 'channel-metadata:publishers' in v] - return regular + occupancy - - def _build_url(self, token): - """ - Build the url to connect to and return it as a string. - - :param token: (parsed) JWT - :type token: splitio.models.token.Token - - :returns: true if the connection was successful. False otherwise. - :rtype: bool - """ - return '{base}/event-stream?v=1.1&accessToken={token}&channels={channels}'.format( - base=self._base_url, - token=token.token, - channels=','.join(self._format_channels(token.channels))) - def start(self, token): """ Open a connection to start listening for events. @@ -148,3 +162,68 @@ def stop(self, blocking=False, timeout=None): self._client.shutdown() if blocking: self._sse_connection_closed.wait(timeout) + +class SplitSSEClientAsync(SplitSSEClientBase): # pylint: disable=too-many-instance-attributes + """Split streaming endpoint SSE client.""" + + def __init__(self, sdk_metadata, client_key=None, base_url='https://streaming.split.io'): + """ + Construct a split sse client. + + :param sdk_metadata: SDK version & machine name & IP. + :type sdk_metadata: splitio.client.util.SdkMetadata + + :param client_key: client key. + :type client_key: str + + :param base_url: scheme + :// + host + :type base_url: str + """ + self._base_url = base_url + self.status = SplitSSEClient._Status.IDLE + self._metadata = headers_from_metadata(sdk_metadata, client_key) + self._client = SSEClientAsync(timeout=self.KEEPALIVE_TIMEOUT) + + async def start(self, token): + """ + Open a connection to start listening for events. + + :param token: (parsed) JWT + :type token: splitio.models.token.Token + + :returns: yield events received from SSEClientAsync object + :rtype: SSEEvent + """ + if self.status != SplitSSEClient._Status.IDLE: + raise Exception('SseClient already started.') + + self.status = SplitSSEClient._Status.CONNECTING + url = self._build_url(token) + try: + sse_events_task = self._client.start(url, extra_headers=self._metadata) + first_event = await anext(sse_events_task) + if first_event.event == SSE_EVENT_ERROR: + await self.stop() + return + self.status = SplitSSEClient._Status.CONNECTED + _LOGGER.debug("Split SSE client started") + yield first_event + while self.status == SplitSSEClient._Status.CONNECTED: + event = await anext(sse_events_task) + if event.data is not None: + yield event + except StopAsyncIteration: + pass + except Exception: # pylint:disable=broad-except + self.status = SplitSSEClient._Status.IDLE + _LOGGER.debug('sse connection ended.') + _LOGGER.debug('stack trace: ', exc_info=True) + + async def stop(self, blocking=False, timeout=None): + """Abort the ongoing connection.""" + _LOGGER.debug("stopping SplitSSE Client") + if self.status == SplitSSEClient._Status.IDLE: + _LOGGER.warning('sse already closed. ignoring') + return + await self._client.shutdown() + self.status = SplitSSEClient._Status.IDLE diff --git a/splitio/push/sse.py b/splitio/push/sse.py index a6e2381c..c7941063 100644 --- a/splitio/push/sse.py +++ b/splitio/push/sse.py @@ -2,11 +2,9 @@ import logging import socket import abc -import urllib from collections import namedtuple from http.client import HTTPConnection, HTTPSConnection from urllib.parse import urlparse -import pytest from splitio.optional.loaders import asyncio, aiohttp from splitio.api.client import HttpClientException @@ -24,24 +22,6 @@ __ENDING_CHARS = set(['\n', '']) -def _get_request_parameters(url, extra_headers): - """ - Parse URL and headers - - :param url: url to connect to - :type url: str - - :param extra_headers: additional headers - :type extra_headers: dict[str, str] - - :returns: processed URL and Headers - :rtype: str, dict - """ - url = urlparse(url) - headers = _DEFAULT_HEADERS.copy() - headers.update(extra_headers if extra_headers is not None else {}) - return url, headers - class EventBuilder(object): """Event builder class.""" @@ -147,7 +127,7 @@ def start(self, url, extra_headers=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT) raise RuntimeError('Client already started.') self._shutdown_requested = False - url, headers = _get_request_parameters(url, extra_headers) + url, headers = urlparse(url), get_headers(extra_headers) self._conn = (HTTPSConnection(url.hostname, url.port, timeout=timeout) if url.scheme == 'https' else HTTPConnection(url.hostname, port=url.port, timeout=timeout)) @@ -171,56 +151,10 @@ def shutdown(self): class SSEClientAsync(SSEClientBase): """SSE Client implementation.""" - def __init__(self, callback): + def __init__(self, timeout=_DEFAULT_ASYNC_TIMEOUT): """ Construct an SSE client. - :param callback: function to call when an event is received - :type callback: callable - """ - self._conn = None - self._event_callback = callback - self._shutdown_requested = False - - async def _read_events(self, response): - """ - Read events from the supplied connection. - - :returns: True if the connection was ended by us. False if it was closed by the serve. - :rtype: bool - """ - try: - event_builder = EventBuilder() - while not self._shutdown_requested: - line = await response.readline() - if line is None or len(line) <= 0: # connection ended - break - elif line.startswith(b':'): # comment. Skip - _LOGGER.debug("skipping sse comment") - continue - elif line in _EVENT_SEPARATORS: - event = event_builder.build() - _LOGGER.debug("dispatching event: %s", event) - await self._event_callback(event) - event_builder = EventBuilder() - else: - event_builder.process_line(line) - except asyncio.CancelledError: - _LOGGER.debug("Cancellation request, proceeding to cancel.") - raise - except Exception: # pylint:disable=broad-except - _LOGGER.debug('sse connection ended.') - _LOGGER.debug('stack trace: ', exc_info=True) - finally: - await self._conn.close() - self._conn = None # clear so it can be started again - - return self._shutdown_requested - - async def start(self, url, extra_headers=None, timeout=_DEFAULT_ASYNC_TIMEOUT): # pylint:disable=protected-access - """ - Connect and start listening for events. - :param url: url to connect to :type url: str @@ -229,36 +163,63 @@ async def start(self, url, extra_headers=None, timeout=_DEFAULT_ASYNC_TIMEOUT): :param timeout: connection & read timeout :type timeout: float + """ + self._conn = None + self._shutdown_requested = False + self._timeout = timeout + self._session = None - :returns: True if the connection was ended by us. False if it was closed by the serve. - :rtype: bool + async def start(self, url, extra_headers=None): # pylint:disable=protected-access + """ + Connect and start listening for events. + + :returns: yield event when received + :rtype: SSEEvent """ _LOGGER.debug("Async SSEClient Started") if self._conn is not None: raise RuntimeError('Client already started.') self._shutdown_requested = False - url = urlparse(url) - headers = _DEFAULT_HEADERS.copy() - headers.update(extra_headers if extra_headers is not None else {}) - parsed_url = urllib.parse.urljoin(url[0] + "://" + url[1], url[2]) - params=url[4] try: self._conn = aiohttp.connector.TCPConnector() async with aiohttp.client.ClientSession( connector=self._conn, - headers=headers, - timeout=aiohttp.ClientTimeout(timeout) + headers=get_headers(extra_headers), + timeout=aiohttp.ClientTimeout(self._timeout) ) as self._session: - reader = await self._session.request( - "GET", - parsed_url, - params=params - ) - return await self._read_events(reader.content) + + self._reader = await self._session.request("GET", url) + try: + event_builder = EventBuilder() + while not self._shutdown_requested: + line = await self._reader.content.readline() + if line is None or len(line) <= 0: # connection ended + raise Exception('connection ended') + elif line.startswith(b':'): # comment. Skip + _LOGGER.debug("skipping sse comment") + continue + elif line in _EVENT_SEPARATORS: + _LOGGER.debug("dispatching event: %s", event_builder.build()) + yield event_builder.build() + event_builder = EventBuilder() + else: + event_builder.process_line(line) + except asyncio.CancelledError: + _LOGGER.debug("Cancellation request, proceeding to cancel.") + raise asyncio.CancelledError() + except Exception: # pylint:disable=broad-except + _LOGGER.debug('sse connection ended.') + _LOGGER.debug('stack trace: ', exc_info=True) + except asyncio.CancelledError: + pass except aiohttp.ClientError as exc: # pylint: disable=broad-except - _LOGGER.error(str(exc)) raise HttpClientException('http client is throwing exceptions') from exc + finally: + await self._conn.close() + self._conn = None # clear so it can be started again + _LOGGER.debug("Existing SSEClient") + return async def shutdown(self): """Shutdown the current connection.""" @@ -272,6 +233,26 @@ async def shutdown(self): return self._shutdown_requested = True - sock = self._session.connector._loop._ssock - sock.shutdown(socket.SHUT_RDWR) - await self._conn.close() \ No newline at end of file + if self._session is not None: + try: + await self._conn.close() + except asyncio.CancelledError: + pass + + +def get_headers(extra=None): + """ + Return default headers with added custom ones if specified. + + :param extra: additional headers + :type extra: dict[str, str] + + :returns: processed Headers + :rtype: dict + """ + headers = _DEFAULT_HEADERS.copy() + headers.update(extra if extra is not None else {}) + return headers + + + diff --git a/splitio/push/workers.py b/splitio/push/workers.py index a5e15fa0..7d035638 100644 --- a/splitio/push/workers.py +++ b/splitio/push/workers.py @@ -130,7 +130,7 @@ def start(self): self._running = True _LOGGER.debug('Starting Segment Worker') - asyncio.get_event_loop().create_task(self._run()) + asyncio.get_running_loop().create_task(self._run()) async def stop(self): """Stop worker.""" @@ -248,7 +248,7 @@ def start(self): self._running = True _LOGGER.debug('Starting Split Worker') - asyncio.get_event_loop().create_task(self._run()) + asyncio.get_running_loop().create_task(self._run()) async def stop(self): """Stop worker.""" diff --git a/splitio/storage/adapters/cache_trait.py b/splitio/storage/adapters/cache_trait.py index 399ee383..e73e7844 100644 --- a/splitio/storage/adapters/cache_trait.py +++ b/splitio/storage/adapters/cache_trait.py @@ -3,7 +3,7 @@ import threading import time from functools import update_wrapper - +from splitio.optional.loaders import asyncio DEFAULT_MAX_AGE = 5 DEFAULT_MAX_SIZE = 100 @@ -84,6 +84,42 @@ def get(self, *args, **kwargs): self._rollover() return node.value + async def get_key(self, key): + """ + Fetch an item from the cache, return None if does not exist + + :param key: User supplied key + :type key: str/frozenset + + :return: Cached/Fetched object + :rtype: object + """ + async with asyncio.Lock(): + node = self._data.get(key) + if node is not None: + if self._is_expired(node): + return None + if node is None: + return None + node = self._bubble_up(node) + return node.value + + async def add_key(self, key, value): + """ + Add an item from the cache. + + :param key: User supplied key + :type key: str/frozenset + + :param value: key value + :type value: str + """ + async with asyncio.Lock(): + node = LocalMemoryCache._Node(key, value, time.time(), None, None) + node = self._bubble_up(node) + self._data[key] = node + self._rollover() + def remove_expired(self): """Remove expired elements.""" with self._lock: @@ -189,4 +225,4 @@ def _decorator(user_function): wrapper = lambda *args, **kwargs: _cache.get(*args, **kwargs) # pylint: disable=unnecessary-lambda return update_wrapper(wrapper, user_function) - return _decorator + return _decorator \ No newline at end of file diff --git a/splitio/storage/adapters/redis.py b/splitio/storage/adapters/redis.py index de3026b3..72abb7cd 100644 --- a/splitio/storage/adapters/redis.py +++ b/splitio/storage/adapters/redis.py @@ -1,10 +1,11 @@ """Redis client wrapper with prefix support.""" from builtins import str - +import abc try: from redis import StrictRedis from redis.sentinel import Sentinel from redis.exceptions import RedisError + import redis.asyncio as aioredis except ImportError: def missing_redis_dependencies(*_, **__): """Fail if missing dependencies are used.""" @@ -12,7 +13,7 @@ def missing_redis_dependencies(*_, **__): 'Missing Redis support dependencies. ' 'Please use `pip install splitio_client[redis]` to install the sdk with redis support' ) - StrictRedis = Sentinel = missing_redis_dependencies + StrictRedis = Sentinel = aioredis = missing_redis_dependencies class RedisAdapterException(Exception): """Exception to be thrown when a redis command fails with an exception.""" @@ -102,8 +103,106 @@ def remove_prefix(self, k): "Cannot remove prefix correctly. Wrong type for key(s) provided" ) +class RedisAdapterBase(object, metaclass=abc.ABCMeta): + """Redis adapter template.""" + + @abc.abstractmethod + def keys(self, pattern): + """Mimic original redis keys.""" + + @abc.abstractmethod + def set(self, name, value, *args, **kwargs): + """Mimic original redis set.""" + + @abc.abstractmethod + def get(self, name): + """Mimic original redis get.""" + + @abc.abstractmethod + def setex(self, name, time, value): + """Mimic original redis setex.""" + + @abc.abstractmethod + def delete(self, *names): + """Mimic original redis delete.""" + + @abc.abstractmethod + def exists(self, name): + """Mimic original redis exists.""" + + @abc.abstractmethod + def lrange(self, key, start, end): + """Mimic original redis lrange.""" + + @abc.abstractmethod + def mget(self, names): + """Mimic original redis mget.""" + + @abc.abstractmethod + def smembers(self, name): + """Mimic original redis smembers.""" + + @abc.abstractmethod + def sadd(self, name, *values): + """Mimic original redis sadd.""" + + @abc.abstractmethod + def srem(self, name, *values): + """Mimic original redis srem.""" + + @abc.abstractmethod + def sismember(self, name, value): + """Mimic original redis sismember.""" + + @abc.abstractmethod + def eval(self, script, number_of_keys, *keys): + """Mimic original redis eval.""" + + @abc.abstractmethod + def hset(self, name, key, value): + """Mimic original redis hset.""" -class RedisAdapter(object): # pylint: disable=too-many-public-methods + @abc.abstractmethod + def hget(self, name, key): + """Mimic original redis hget.""" + + @abc.abstractmethod + def hincrby(self, name, key, amount=1): + """Mimic original redis hincrby.""" + + @abc.abstractmethod + def incr(self, name, amount=1): + """Mimic original redis incr.""" + + @abc.abstractmethod + def getset(self, name, value): + """Mimic original redis getset.""" + + @abc.abstractmethod + def rpush(self, key, *values): + """Mimic original redis rpush.""" + + @abc.abstractmethod + def expire(self, key, value): + """Mimic original redis expire.""" + + @abc.abstractmethod + def rpop(self, key): + """Mimic original redis rpop.""" + + @abc.abstractmethod + def ttl(self, key): + """Mimic original redis ttl.""" + + @abc.abstractmethod + def lpop(self, key): + """Mimic original redis lpop.""" + + @abc.abstractmethod + def pipeline(self): + """Mimic original redis pipeline.""" + +class RedisAdapter(RedisAdapterBase): # pylint: disable=too-many-public-methods """ Instance decorator for Redis clients such as StrictRedis. @@ -303,7 +402,230 @@ def pipeline(self): except RedisError as exc: raise RedisAdapterException('Error executing ttl operation') from exc -class RedisPipelineAdapter(object): + +class RedisAdapterAsync(RedisAdapterBase): # pylint: disable=too-many-public-methods + """ + Instance decorator for asyncio Redis clients such as StrictRedis. + + Adds an extra layer handling addition/removal of user prefix when handling + keys + """ + def __init__(self, decorated, prefix=None): + """ + Store the user prefix and the redis client instance. + + :param decorated: Instance of redis cache client to decorate. + :param prefix: User prefix to add. + """ + self._decorated = decorated + self._prefix_helper = PrefixHelper(prefix) + + # Below starts a list of methods that implement the interface of a standard + # redis client. + + async def keys(self, pattern): + """Mimic original redis function but using user custom prefix.""" + try: + return [ + key + for key in self._prefix_helper.remove_prefix(await self._decorated.keys(self._prefix_helper.add_prefix(pattern))) + ] + except RedisError as exc: + raise RedisAdapterException('Failed to execute keys operation') from exc + + async def set(self, name, value, *args, **kwargs): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.set( + self._prefix_helper.add_prefix(name), value, *args, **kwargs + ) + except RedisError as exc: + raise RedisAdapterException('Failed to execute set operation') from exc + + async def get(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.get(self._prefix_helper.add_prefix(name)) + except RedisError as exc: + raise RedisAdapterException('Error executing get operation') from exc + + async def setex(self, name, time, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.setex(self._prefix_helper.add_prefix(name), time, value) + except RedisError as exc: + raise RedisAdapterException('Error executing setex operation') from exc + + async def delete(self, *names): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.delete(*self._prefix_helper.add_prefix(list(names))) + except RedisError as exc: + raise RedisAdapterException('Error executing delete operation') from exc + + async def exists(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.exists(self._prefix_helper.add_prefix(name)) + except RedisError as exc: + raise RedisAdapterException('Error executing exists operation') from exc + + async def lrange(self, key, start, end): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.lrange(self._prefix_helper.add_prefix(key), start, end) + except RedisError as exc: + raise RedisAdapterException('Error executing exists operation') from exc + + async def mget(self, names): + """Mimic original redis function but using user custom prefix.""" + try: + return [ + item + for item in await self._decorated.mget(self._prefix_helper.add_prefix(names)) + ] + except RedisError as exc: + raise RedisAdapterException('Error executing mget operation') from exc + + async def smembers(self, name): + """Mimic original redis function but using user custom prefix.""" + try: + return [ + item + for item in await self._decorated.smembers(self._prefix_helper.add_prefix(name)) + ] + except RedisError as exc: + raise RedisAdapterException('Error executing smembers operation') from exc + + async def sadd(self, name, *values): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.sadd(self._prefix_helper.add_prefix(name), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing sadd operation') from exc + + async def srem(self, name, *values): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.srem(self._prefix_helper.add_prefix(name), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing srem operation') from exc + + async def sismember(self, name, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.sismember(self._prefix_helper.add_prefix(name), value) + except RedisError as exc: + raise RedisAdapterException('Error executing sismember operation') from exc + + async def eval(self, script, number_of_keys, *keys): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.eval(script, number_of_keys, *self._prefix_helper.add_prefix(list(keys))) + except RedisError as exc: + raise RedisAdapterException('Error executing eval operation') from exc + + async def hset(self, name, key, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.hset(self._prefix_helper.add_prefix(name), key, value) + except RedisError as exc: + raise RedisAdapterException('Error executing hset operation') from exc + + async def hget(self, name, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.hget(self._prefix_helper.add_prefix(name), key) + except RedisError as exc: + raise RedisAdapterException('Error executing hget operation') from exc + + async def hincrby(self, name, key, amount=1): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.hincrby(self._prefix_helper.add_prefix(name), key, amount) + except RedisError as exc: + raise RedisAdapterException('Error executing hincrby operation') from exc + + async def incr(self, name, amount=1): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.incr(self._prefix_helper.add_prefix(name), amount) + except RedisError as exc: + raise RedisAdapterException('Error executing incr operation') from exc + + async def getset(self, name, value): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.getset(self._prefix_helper.add_prefix(name), value) + except RedisError as exc: + raise RedisAdapterException('Error executing getset operation') from exc + + async def rpush(self, key, *values): + """Mimic original redis function but using user custom prefix.""" + try: + async with self._decorated.client() as conn: + return await conn.rpush(self._prefix_helper.add_prefix(key), *values) + except RedisError as exc: + raise RedisAdapterException('Error executing rpush operation') from exc + + async def expire(self, key, value): + """Mimic original redis function but using user custom prefix.""" + try: + async with self._decorated.client() as conn: + return await conn.expire(self._prefix_helper.add_prefix(key), value) + except RedisError as exc: + raise RedisAdapterException('Error executing expire operation') from exc + + async def rpop(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.rpop(self._prefix_helper.add_prefix(key)) + except RedisError as exc: + raise RedisAdapterException('Error executing rpop operation') from exc + + async def ttl(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.ttl(self._prefix_helper.add_prefix(key)) + except RedisError as exc: + raise RedisAdapterException('Error executing ttl operation') from exc + + async def lpop(self, key): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._decorated.lpop(self._prefix_helper.add_prefix(key)) + except RedisError as exc: + raise RedisAdapterException('Error executing lpop operation') from exc + + def pipeline(self): + """Mimic original redis pipeline.""" + try: + return RedisPipelineAdapterAsync(self._decorated, self._prefix_helper) + except RedisError as exc: + raise RedisAdapterException('Error executing ttl operation') from exc + +class RedisPipelineAdapterBase(object, metaclass=abc.ABCMeta): + """ + Template decorator for Redis Pipeline. + """ + @abc.abstractmethod + def rpush(self, key, *values): + """Mimic original redis function but using user custom prefix.""" + + @abc.abstractmethod + def incr(self, name, amount=1): + """Mimic original redis function but using user custom prefix.""" + + @abc.abstractmethod + def hincrby(self, name, key, amount=1): + """Mimic original redis function but using user custom prefix.""" + + @abc.abstractmethod + def execute(self): + """Mimic original redis execute.""" + + +class RedisPipelineAdapter(RedisPipelineAdapterBase): """ Instance decorator for Redis Pipeline. @@ -340,6 +662,43 @@ def execute(self): raise RedisAdapterException('Error executing pipeline operation') from exc +class RedisPipelineAdapterAsync(RedisPipelineAdapterBase): + """ + Instance decorator for Asyncio Redis Pipeline. + + Adds an extra layer handling addition/removal of user prefix when handling + keys + """ + def __init__(self, decorated, prefix_helper): + """ + Store the user prefix and the redis client instance. + + :param decorated: Instance of redis cache client to decorate. + :param _prefix_helper: PrefixHelper utility + """ + self._prefix_helper = prefix_helper + self._pipe = decorated.pipeline() + + async def rpush(self, key, *values): + """Mimic original redis function but using user custom prefix.""" + await self._pipe.rpush(self._prefix_helper.add_prefix(key), *values) + + async def incr(self, name, amount=1): + """Mimic original redis function but using user custom prefix.""" + await self._pipe.incr(self._prefix_helper.add_prefix(name), amount) + + async def hincrby(self, name, key, amount=1): + """Mimic original redis function but using user custom prefix.""" + await self._pipe.hincrby(self._prefix_helper.add_prefix(name), key, amount) + + async def execute(self): + """Mimic original redis function but using user custom prefix.""" + try: + return await self._pipe.execute() + except RedisError as exc: + raise RedisAdapterException('Error executing pipeline operation') from exc + + def _build_default_client(config): # pylint: disable=too-many-locals """ Build a redis adapter. @@ -398,6 +757,63 @@ def _build_default_client(config): # pylint: disable=too-many-locals ) return RedisAdapter(redis, prefix=prefix) +async def _build_default_client_async(config): # pylint: disable=too-many-locals + """ + Build a redis asyncio adapter. + + :param config: Redis configuration properties + :type config: dict + + :return: A wrapped Redis object + :rtype: splitio.storage.adapters.redis.RedisAdapterAsync + """ + host = config.get('redisHost', 'localhost') + port = config.get('redisPort', 6379) + database = config.get('redisDb', 0) + password = config.get('redisPassword', None) + socket_timeout = config.get('redisSocketTimeout', None) + socket_connect_timeout = config.get('redisSocketConnectTimeout', None) + socket_keepalive = config.get('redisSocketKeepalive', None) + socket_keepalive_options = config.get('redisSocketKeepaliveOptions', None) + connection_pool = config.get('redisConnectionPool', None) + unix_socket_path = config.get('redisUnixSocketPath', None) + encoding = config.get('redisEncoding', 'utf-8') + encoding_errors = config.get('redisEncodingErrors', 'strict') + errors = config.get('redisErrors', None) + decode_responses = config.get('redisDecodeResponses', True) + retry_on_timeout = config.get('redisRetryOnTimeout', False) + ssl = config.get('redisSsl', False) + ssl_keyfile = config.get('redisSslKeyfile', None) + ssl_certfile = config.get('redisSslCertfile', None) + ssl_cert_reqs = config.get('redisSslCertReqs', None) + ssl_ca_certs = config.get('redisSslCaCerts', None) + max_connections = config.get('redisMaxConnections', None) + prefix = config.get('redisPrefix') + + redis = await aioredis.from_url( + "redis://" + host + ":" + str(port), + db=database, + password=password, + timeout=socket_timeout, + socket_connect_timeout=socket_connect_timeout, + socket_keepalive=socket_keepalive, + socket_keepalive_options=socket_keepalive_options, + connection_pool=connection_pool, + unix_socket_path=unix_socket_path, + encoding=encoding, + encoding_errors=encoding_errors, + errors=errors, + decode_responses=decode_responses, + retry_on_timeout=retry_on_timeout, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + max_connections=max_connections + ) + return RedisAdapterAsync(redis, prefix=prefix) + def _build_sentinel_client(config): # pylint: disable=too-many-locals """ @@ -464,6 +880,18 @@ def _build_sentinel_client(config): # pylint: disable=too-many-locals return RedisAdapter(redis, prefix=prefix) +async def build_async(config): + """ + Build a async redis storage according to the configuration received. + + :param config: SDK Configuration parameters with redis properties. + :type config: dict. + + :return: A redis async client + :rtype: splitio.storage.adapters.redis.RedisAdapterAsync + """ + return await _build_default_client_async(config) + def build(config): """ Build a redis storage according to the configuration received. diff --git a/splitio/storage/redis.py b/splitio/storage/redis.py index d2aa2788..abf596b2 100644 --- a/splitio/storage/redis.py +++ b/splitio/storage/redis.py @@ -5,36 +5,24 @@ from splitio.models.impressions import Impression from splitio.models import splits, segments -from splitio.models.telemetry import MethodExceptions, MethodLatencies, TelemetryConfig, get_latency_bucket_index +from splitio.models.telemetry import TelemetryConfig, get_latency_bucket_index, TelemetryConfigAsync from splitio.storage import SplitStorage, SegmentStorage, ImpressionStorage, EventStorage, \ ImpressionPipelinedStorage, TelemetryStorage from splitio.storage.adapters.redis import RedisAdapterException from splitio.storage.adapters.cache_trait import decorate as add_cache, DEFAULT_MAX_AGE - +from splitio.optional.loaders import asyncio +from splitio.storage.adapters.cache_trait import LocalMemoryCache _LOGGER = logging.getLogger(__name__) MAX_TAGS = 10 -class RedisSplitStorage(SplitStorage): - """Redis-based storage for splits.""" +class RedisSplitStorageBase(SplitStorage): + """Redis-based storage template for splits.""" _SPLIT_KEY = 'SPLITIO.split.{split_name}' _SPLIT_TILL_KEY = 'SPLITIO.splits.till' _TRAFFIC_TYPE_KEY = 'SPLITIO.trafficType.{traffic_type_name}' - def __init__(self, redis_client, enable_caching=False, max_age=DEFAULT_MAX_AGE): - """ - Class constructor. - - :param redis_client: Redis client or compliant interface. - :type redis_client: splitio.storage.adapters.redis.RedisAdapter - """ - self._redis = redis_client - if enable_caching: - self.get = add_cache(lambda *p, **_: p[0], max_age)(self.get) - self.is_valid_traffic_type = add_cache(lambda *p, **_: p[0], max_age)(self.is_valid_traffic_type) # pylint: disable=line-too-long - self.fetch_many = add_cache(lambda *p, **_: frozenset(p[0]), max_age)(self.fetch_many) - def _get_key(self, split_name): """ Use the provided split_name to build the appropriate redis key. @@ -59,6 +47,98 @@ def _get_traffic_type_key(self, traffic_type_name): """ return self._TRAFFIC_TYPE_KEY.format(traffic_type_name=traffic_type_name) + def put(self, split): + """ + Store a split. + + :param split: Split object to store + :type split_name: splitio.models.splits.Split + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def remove(self, split_name): + """ + Remove a split from storage. + + :param split_name: Name of the feature to remove. + :type split_name: str + + :return: True if the split was found and removed. False otherwise. + :rtype: bool + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def set_change_number(self, new_change_number): + """ + Set the latest change number. + + :param new_change_number: New change number. + :type new_change_number: int + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def get_splits_count(self): + """ + Return splits count. + + :rtype: int + """ + return 0 + + def kill_locally(self, split_name, default_treatment, change_number): + """ + Local kill for split + + :param split_name: name of the split to perform kill + :type split_name: str + :param default_treatment: name of the default treatment to return + :type default_treatment: str + :param change_number: change_number + :type change_number: int + """ + raise NotImplementedError('Not supported for redis.') + + def get(self, split_name): # pylint: disable=method-hidden + """Retrieve a split.""" + pass + + def fetch_many(self, split_names): + """Retrieve splits.""" + pass + + def is_valid_traffic_type(self, traffic_type_name): # pylint: disable=method-hidden + """Return whether the traffic type exists in at least one split in cache.""" + pass + + def get_change_number(self): + """Retrieve latest split change number.""" + pass + + def get_split_names(self): + """Retrieve a list of all split names.""" + pass + + def get_all_splits(self): + """Return all the splits in cache.""" + pass + + +class RedisSplitStorage(RedisSplitStorageBase): + """Redis-based storage for splits.""" + + def __init__(self, redis_client, enable_caching=False, max_age=DEFAULT_MAX_AGE): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + self._redis = redis_client + if enable_caching: + self.get = add_cache(lambda *p, **_: p[0], max_age)(self.get) + self.is_valid_traffic_type = add_cache(lambda *p, **_: p[0], max_age)(self.is_valid_traffic_type) # pylint: disable=line-too-long + self.fetch_many = add_cache(lambda *p, **_: frozenset(p[0]), max_age)(self.fetch_many) + def get(self, split_name): # pylint: disable=method-hidden """ Retrieve a split. @@ -128,27 +208,6 @@ def is_valid_traffic_type(self, traffic_type_name): # pylint: disable=method-hi _LOGGER.debug('Error: ', exc_info=True) return False - def put(self, split): - """ - Store a split. - - :param split: Split object to store - :type split_name: splitio.models.splits.Split - """ - raise NotImplementedError('Only redis-consumer mode is supported.') - - def remove(self, split_name): - """ - Remove a split from storage. - - :param split_name: Name of the feature to remove. - :type split_name: str - - :return: True if the split was found and removed. False otherwise. - :rtype: bool - """ - raise NotImplementedError('Only redis-consumer mode is supported.') - def get_change_number(self): """ Retrieve latest split change number. @@ -164,15 +223,6 @@ def get_change_number(self): _LOGGER.debug('Error: ', exc_info=True) return None - def set_change_number(self, new_change_number): - """ - Set the latest change number. - - :param new_change_number: New change number. - :type new_change_number: int - """ - raise NotImplementedError('Only redis-consumer mode is supported.') - def get_split_names(self): """ Retrieve a list of all split names. @@ -189,14 +239,6 @@ def get_split_names(self): _LOGGER.debug('Error: ', exc_info=True) return [] - def get_splits_count(self): - """ - Return splits count. - - :rtype: int - """ - return 0 - def get_all_splits(self): """ Return all the splits in cache. @@ -220,18 +262,153 @@ def get_all_splits(self): _LOGGER.debug('Error: ', exc_info=True) return to_return - def kill_locally(self, split_name, default_treatment, change_number): + +class RedisSplitStorageAsync(RedisSplitStorage): + """Async Redis-based storage for splits.""" + + def __init__(self, redis_client, enable_caching=False, max_age=DEFAULT_MAX_AGE): """ - Local kill for split + Class constructor. - :param split_name: name of the split to perform kill + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + """ + self._redis = redis_client + self._enable_caching = enable_caching + if enable_caching: + self._cache = LocalMemoryCache(None, None, max_age) + + async def get(self, split_name): # pylint: disable=method-hidden + """ + Retrieve a split. + + :param split_name: Name of the feature to fetch. :type split_name: str - :param default_treatment: name of the default treatment to return - :type default_treatment: str - :param change_number: change_number - :type change_number: int + + :return: A split object parsed from redis if the key exists. None otherwise + :rtype: splitio.models.splits.Split """ - raise NotImplementedError('Not supported for redis.') + try: + if self._enable_caching and await self._cache.get_key(split_name) is not None: + raw = await self._cache.get_key(split_name) + else: + raw = await self._redis.get(self._get_key(split_name)) + if self._enable_caching: + await self._cache.add_key(split_name, raw) + _LOGGER.debug("Fetchting Split [%s] from redis" % split_name) + _LOGGER.debug(raw) + return splits.from_raw(json.loads(raw)) if raw is not None else None + except RedisAdapterException: + _LOGGER.error('Error fetching split from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def fetch_many(self, split_names): + """ + Retrieve splits. + + :param split_names: Names of the features to fetch. + :type split_name: list(str) + + :return: A dict with split objects parsed from redis. + :rtype: dict(split_name, splitio.models.splits.Split) + """ + to_return = dict() + try: + if self._enable_caching and await self._cache.get_key(frozenset(split_names)) is not None: + raw_splits = await self._cache.get_key(frozenset(split_names)) + else: + keys = [self._get_key(split_name) for split_name in split_names] + raw_splits = await self._redis.mget(keys) + if self._enable_caching: + await self._cache.add_key(frozenset(split_names), raw_splits) + for i in range(len(split_names)): + split = None + try: + split = splits.from_raw(json.loads(raw_splits[i])) + except (ValueError, TypeError): + _LOGGER.error('Could not parse split.') + _LOGGER.debug("Raw split that failed parsing attempt: %s", raw_splits[i]) + to_return[split_names[i]] = split + except RedisAdapterException: + _LOGGER.error('Error fetching splits from storage') + _LOGGER.debug('Error: ', exc_info=True) + return to_return + + async def is_valid_traffic_type(self, traffic_type_name): # pylint: disable=method-hidden + """ + Return whether the traffic type exists in at least one split in cache. + + :param traffic_type_name: Traffic type to validate. + :type traffic_type_name: str + + :return: True if the traffic type is valid. False otherwise. + :rtype: bool + """ + try: + if self._enable_caching and await self._cache.get_key(traffic_type_name) is not None: + raw = await self._cache.get_key(traffic_type_name) + else: + raw = await self._redis.get(self._get_traffic_type_key(traffic_type_name)) + if self._enable_caching: + await self._cache.add_key(traffic_type_name, raw) + count = json.loads(raw) if raw else 0 + return count > 0 + except RedisAdapterException: + _LOGGER.error('Error fetching split from storage') + _LOGGER.debug('Error: ', exc_info=True) + return False + + async def get_change_number(self): + """ + Retrieve latest split change number. + + :rtype: int + """ + try: + stored_value = await self._redis.get(self._SPLIT_TILL_KEY) + return json.loads(stored_value) if stored_value is not None else None + except RedisAdapterException: + _LOGGER.error('Error fetching split change number from storage') + _LOGGER.debug('Error: ', exc_info=True) + return None + + async def get_split_names(self): + """ + Retrieve a list of all split names. + + :return: List of split names. + :rtype: list(str) + """ + try: + keys = await self._redis.keys(self._get_key('*')) + return [key.replace(self._get_key(''), '') for key in keys] + except RedisAdapterException: + _LOGGER.error('Error fetching split names from storage') + _LOGGER.debug('Error: ', exc_info=True) + return [] + + async def get_all_splits(self): + """ + Return all the splits in cache. + + :return: List of all splits in cache. + :rtype: list(splitio.models.splits.Split) + """ + keys = await self._redis.keys(self._get_key('*')) + to_return = [] + try: + raw_splits = await self._redis.mget(keys) + for raw in raw_splits: + try: + to_return.append(splits.from_raw(json.loads(raw))) + except (ValueError, TypeError): + _LOGGER.error('Could not parse split. Skipping') + _LOGGER.debug("Raw split that failed parsing attempt: %s", raw) + except RedisAdapterException: + _LOGGER.error('Error fetching all splits from storage') + _LOGGER.debug('Error: ', exc_info=True) + return to_return class RedisSegmentStorage(SegmentStorage): @@ -385,24 +562,12 @@ def get_segments_keys_count(self): """ return 0 -class RedisImpressionsStorage(ImpressionStorage, ImpressionPipelinedStorage): - """Redis based event storage class.""" +class RedisImpressionsStorageBase(ImpressionStorage, ImpressionPipelinedStorage): + """Redis based event storage base class.""" IMPRESSIONS_QUEUE_KEY = 'SPLITIO.impressions' IMPRESSIONS_KEY_DEFAULT_TTL = 3600 - def __init__(self, redis_client, sdk_metadata): - """ - Class constructor. - - :param redis_client: Redis client or compliant interface. - :type redis_client: splitio.storage.adapters.redis.RedisAdapter - :param sdk_metadata: SDK & Machine information. - :type sdk_metadata: splitio.client.util.SdkMetadata - """ - self._redis = redis_client - self._sdk_metadata = sdk_metadata - def _wrap_impressions(self, impressions): """ Wrap impressions to be stored in redis @@ -444,8 +609,7 @@ def expire_key(self, total_keys, inserted): :param inserted: added keys. :type inserted: int """ - if total_keys == inserted: - self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL) + pass def add_impressions_to_pipe(self, impressions, pipe): """ @@ -461,6 +625,61 @@ def add_impressions_to_pipe(self, impressions, pipe): _LOGGER.debug(bulk_impressions) pipe.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) + def put(self, impressions): + """ + Add an impression to the redis storage. + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Whether the impression has been added or not. + :rtype: bool + """ + pass + + def pop_many(self, count): + """ + Pop the oldest N events from storage. + + :param count: Number of events to pop. + :type count: int + """ + raise NotImplementedError('Only redis-consumer mode is supported.') + + def clear(self): + """ + Clear data. + """ + raise NotImplementedError('Not supported for redis.') + + +class RedisImpressionsStorage(RedisImpressionsStorageBase): + """Redis based event storage class.""" + + def __init__(self, redis_client, sdk_metadata): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._redis = redis_client + self._sdk_metadata = sdk_metadata + + def expire_key(self, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL) + def put(self, impressions): """ Add an impression to the redis storage. @@ -483,20 +702,55 @@ def put(self, impressions): _LOGGER.error('Error: ', exc_info=True) return False - def pop_many(self, count): + +class RedisImpressionsStorageAsync(RedisImpressionsStorageBase): + """Redis based event storage async class.""" + + def __init__(self, redis_client, sdk_metadata): """ - Pop the oldest N events from storage. + Class constructor. - :param count: Number of events to pop. - :type count: int + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata """ - raise NotImplementedError('Only redis-consumer mode is supported.') + self._redis = redis_client + self._sdk_metadata = sdk_metadata - def clear(self): + async def expire_key(self, total_keys, inserted): """ - Clear data. + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int """ - raise NotImplementedError('Not supported for redis.') + if total_keys == inserted: + await self._redis.expire(self.IMPRESSIONS_QUEUE_KEY, self.IMPRESSIONS_KEY_DEFAULT_TTL) + + async def put(self, impressions): + """ + Add an impression to the redis storage. + + :param impressions: Impression to add to the queue. + :type impressions: splitio.models.impressions.Impression + + :return: Whether the impression has been added or not. + :rtype: bool + """ + bulk_impressions = self._wrap_impressions(impressions) + try: + _LOGGER.debug("Adding Impressions to redis key %s" % (self.IMPRESSIONS_QUEUE_KEY)) + _LOGGER.debug(bulk_impressions) + inserted = await self._redis.rpush(self.IMPRESSIONS_QUEUE_KEY, *bulk_impressions) + await self.expire_key(inserted, len(bulk_impressions)) + return True + except RedisAdapterException: + _LOGGER.error('Something went wrong when trying to add impression to redis') + _LOGGER.error('Error: ', exc_info=True) + return False class RedisEventsStorage(EventStorage): @@ -600,7 +854,7 @@ def expire_keys(self, total_keys, inserted): if total_keys == inserted: self._redis.expire(self._EVENTS_KEY_TEMPLATE, self._EVENTS_KEY_DEFAULT_TTL) -class RedisTelemetryStorage(TelemetryStorage): +class RedisTelemetryStorageBase(TelemetryStorage): """Redis based telemetry storage class.""" _TELEMETRY_CONFIG_KEY = 'SPLITIO.telemetry.init' @@ -608,33 +862,13 @@ class RedisTelemetryStorage(TelemetryStorage): _TELEMETRY_EXCEPTIONS_KEY = 'SPLITIO.telemetry.exceptions' _TELEMETRY_KEY_DEFAULT_TTL = 3600 - def __init__(self, redis_client, sdk_metadata): - """ - Class constructor. - - :param redis_client: Redis client or compliant interface. - :type redis_client: splitio.storage.adapters.redis.RedisAdapter - :param sdk_metadata: SDK & Machine information. - :type sdk_metadata: splitio.client.util.SdkMetadata - """ - self._lock = threading.RLock() - self._reset_config_tags() - self._redis_client = redis_client - self._sdk_metadata = sdk_metadata - self._method_latencies = MethodLatencies() - self._method_exceptions = MethodExceptions() - self._tel_config = TelemetryConfig() - self._make_pipe = redis_client.pipeline - def _reset_config_tags(self): - with self._lock: - self._config_tags = [] + """Reset all config tags""" + pass def add_config_tag(self, tag): """Record tag string.""" - with self._lock: - if len(self._config_tags) < MAX_TAGS: - self._config_tags.append(tag) + pass def record_config(self, config, extra_config): """ @@ -643,35 +877,29 @@ def record_config(self, config, extra_config): :param congif: factory configuration parameters :type config: splitio.client.config """ - self._tel_config.record_config(config, extra_config) + pass def pop_config_tags(self): """Get and reset tags.""" - with self._lock: - tags = self._config_tags - self._reset_config_tags() - return tags + pass def push_config_stats(self): """push config stats to redis.""" - _LOGGER.debug("Adding Config stats to redis key %s" % (self._TELEMETRY_CONFIG_KEY)) - _LOGGER.debug(str(self._format_config_stats())) - self._redis_client.hset(self._TELEMETRY_CONFIG_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip, str(self._format_config_stats())) + pass - def _format_config_stats(self): + def _format_config_stats(self, config_stats, tags): """format only selected config stats to json""" - config_stats = self._tel_config.get_stats() return json.dumps({ 'aF': config_stats['aF'], 'rF': config_stats['rF'], 'sT': config_stats['sT'], 'oM': config_stats['oM'], - 't': self.pop_config_tags() + 't': tags }) def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): """Record active and redundant factories.""" - self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + pass def add_latency_to_pipe(self, method, bucket, pipe): """ @@ -703,14 +931,7 @@ def record_exception(self, method): :param method: method name :type method: string """ - _LOGGER.debug("Adding Excepction stats to redis key %s" % (self._TELEMETRY_EXCEPTIONS_KEY)) - _LOGGER.debug(self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + - method.value) - pipe = self._make_pipe() - pipe.hincrby(self._TELEMETRY_EXCEPTIONS_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + - method.value, 1) - result = pipe.execute() - self.expire_keys(self._TELEMETRY_EXCEPTIONS_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result[0]) + pass def record_not_ready_usage(self): """ @@ -730,6 +951,105 @@ def record_impression_stats(self, data_type, count): pass def expire_latency_keys(self, total_keys, inserted): + pass + + def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + pass + + +class RedisTelemetryStorage(RedisTelemetryStorageBase): + """Redis based telemetry storage class.""" + + def __init__(self, redis_client, sdk_metadata): + """ + Class constructor. + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + """ + self._lock = threading.RLock() + self._reset_config_tags() + self._redis_client = redis_client + self._sdk_metadata = sdk_metadata + self._tel_config = TelemetryConfig() + self._make_pipe = redis_client.pipeline + + def _reset_config_tags(self): + """Reset all config tags""" + with self._lock: + self._config_tags = [] + + def add_config_tag(self, tag): + """Record tag string.""" + with self._lock: + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + def record_config(self, config, extra_config): + """ + initilize telemetry objects + + :param congif: factory configuration parameters + :type config: splitio.client.config + """ + self._tel_config.record_config(config, extra_config) + + def pop_config_tags(self): + """Get and reset tags.""" + with self._lock: + tags = self._config_tags + self._reset_config_tags() + return tags + + def push_config_stats(self): + """push config stats to redis.""" + _LOGGER.debug("Adding Config stats to redis key %s" % (self._TELEMETRY_CONFIG_KEY)) + _LOGGER.debug(str(self._format_config_stats(self._tel_config.get_stats(), self.pop_config_tags()))) + self._redis_client.hset(self._TELEMETRY_CONFIG_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip, str(self._format_config_stats(self._tel_config.get_stats(), self.pop_config_tags()))) + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + _LOGGER.debug("Adding Excepction stats to redis key %s" % (self._TELEMETRY_EXCEPTIONS_KEY)) + _LOGGER.debug(self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value) + pipe = self._make_pipe() + pipe.hincrby(self._TELEMETRY_EXCEPTIONS_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value, 1) + result = pipe.execute() + self.expire_keys(self._TELEMETRY_EXCEPTIONS_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result[0]) + + def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + def expire_latency_keys(self, total_keys, inserted): + """ + Expire lstency keys + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ self.expire_keys(self._TELEMETRY_LATENCIES_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, total_keys, inserted) def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): @@ -743,3 +1063,100 @@ def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): """ if total_keys == inserted: self._redis_client.expire(queue_key, key_default_ttl) + + +class RedisTelemetryStorageAsync(RedisTelemetryStorageBase): + """Redis based telemetry async storage class.""" + + async def create(redis_client, sdk_metadata): + """ + Create instance and reset tags + + :param redis_client: Redis client or compliant interface. + :type redis_client: splitio.storage.adapters.redis.RedisAdapter + :param sdk_metadata: SDK & Machine information. + :type sdk_metadata: splitio.client.util.SdkMetadata + + :return: self instance. + :rtype: splitio.storage.redis.RedisTelemetryStorageAsync + """ + self = RedisTelemetryStorageAsync() + await self._reset_config_tags() + self._redis_client = redis_client + self._sdk_metadata = sdk_metadata + self._tel_config = await TelemetryConfigAsync.create() + self._make_pipe = redis_client.pipeline + return self + + async def _reset_config_tags(self): + """Reset all config tags""" + self._config_tags = [] + + async def add_config_tag(self, tag): + """Record tag string.""" + if len(self._config_tags) < MAX_TAGS: + self._config_tags.append(tag) + + async def record_config(self, config, extra_config): + """ + initilize telemetry objects + + :param congif: factory configuration parameters + :type config: splitio.client.config + """ + await self._tel_config.record_config(config, extra_config) + + async def pop_config_tags(self): + """Get and reset tags.""" + tags = self._config_tags + await self._reset_config_tags() + return tags + + async def push_config_stats(self): + """push config stats to redis.""" + _LOGGER.debug("Adding Config stats to redis key %s" % (self._TELEMETRY_CONFIG_KEY)) + _LOGGER.debug(str(await self._format_config_stats(await self._tel_config.get_stats(), await self.pop_config_tags()))) + await self._redis_client.hset(self._TELEMETRY_CONFIG_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip, str(await self._format_config_stats(await self._tel_config.get_stats(), await self.pop_config_tags()))) + + async def record_exception(self, method): + """ + record an exception + + :param method: method name + :type method: string + """ + _LOGGER.debug("Adding Excepction stats to redis key %s" % (self._TELEMETRY_EXCEPTIONS_KEY)) + _LOGGER.debug(self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value) + pipe = self._make_pipe() + pipe.hincrby(self._TELEMETRY_EXCEPTIONS_KEY, self._sdk_metadata.sdk_version + '/' + self._sdk_metadata.instance_name + '/' + self._sdk_metadata.instance_ip + '/' + + method.value, 1) + result = await pipe.execute() + await self.expire_keys(self._TELEMETRY_EXCEPTIONS_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, 1, result[0]) + + async def record_active_and_redundant_factories(self, active_factory_count, redundant_factory_count): + """Record active and redundant factories.""" + await self._tel_config.record_active_and_redundant_factories(active_factory_count, redundant_factory_count) + + async def expire_latency_keys(self, total_keys, inserted): + """ + Expire lstency keys + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + await self.expire_keys(self._TELEMETRY_LATENCIES_KEY, self._TELEMETRY_KEY_DEFAULT_TTL, total_keys, inserted) + + async def expire_keys(self, queue_key, key_default_ttl, total_keys, inserted): + """ + Set expire + + :param total_keys: length of keys. + :type total_keys: int + :param inserted: added keys. + :type inserted: int + """ + if total_keys == inserted: + await self._redis_client.expire(queue_key, key_default_ttl) diff --git a/tests/api/test_httpclient.py b/tests/api/test_httpclient.py index 2d9614ab..afcd19cb 100644 --- a/tests/api/test_httpclient.py +++ b/tests/api/test_httpclient.py @@ -223,6 +223,7 @@ async def test_get_custom_urls(self, mocker): assert get_mock.mock_calls == [call] + @pytest.mark.asyncio async def test_post(self, mocker): """Test HTTP POST verb requests.""" response_mock = MockResponse('ok', 200, {}) @@ -255,6 +256,7 @@ async def test_post(self, mocker): assert response.body == 'ok' assert get_mock.mock_calls == [call] + @pytest.mark.asyncio async def test_post_custom_urls(self, mocker): """Test HTTP GET verb requests.""" response_mock = MockResponse('ok', 200, {}) diff --git a/tests/push/test_manager.py b/tests/push/test_manager.py index d2999171..49746b56 100644 --- a/tests/push/test_manager.py +++ b/tests/push/test_manager.py @@ -259,14 +259,14 @@ async def sse_loop_mock(se, token): await asyncio.sleep(1) assert await feedback_loop.get() == Status.PUSH_SUBSYSTEM_UP - assert self.token.push_enabled == True + assert self.token.push_enabled assert self.token.token == 'abc' assert self.token.channels == {} assert self.token.exp == 2000000 assert self.token.iat == 1000000 - assert(telemetry_storage._streaming_events._streaming_events[1]._type == StreamingEventTypes.TOKEN_REFRESH.value) - assert(telemetry_storage._streaming_events._streaming_events[0]._type == StreamingEventTypes.CONNECTION_ESTABLISHED.value) + # assert(telemetry_storage._streaming_events._streaming_events[1]._type == StreamingEventTypes.TOKEN_REFRESH.value) + # assert(telemetry_storage._streaming_events._streaming_events[0]._type == StreamingEventTypes.CONNECTION_ESTABLISHED.value) @pytest.mark.asyncio async def test_connection_failure(self, mocker): @@ -303,9 +303,11 @@ async def authenticate(): sse_constructor_mock.return_value = sse_mock mocker.patch('splitio.push.manager.SplitSSEClientAsync', new=sse_constructor_mock) feedback_loop = asyncio.Queue() - telemetry_storage = InMemoryTelemetryStorage() - telemetry_producer = TelemetryStorageProducer(telemetry_storage) + + telemetry_storage = await InMemoryTelemetryStorageAsync.create() + telemetry_producer = TelemetryStorageProducerAsync(telemetry_storage) telemetry_runtime_producer = telemetry_producer.get_telemetry_runtime_producer() + manager = PushManagerAsync(api_mock, mocker.Mock(), feedback_loop, mocker.Mock(), telemetry_runtime_producer) await manager.start() assert await feedback_loop.get() == Status.PUSH_NONRETRYABLE_ERROR diff --git a/tests/push/test_processor.py b/tests/push/test_processor.py index aa6cf52f..1e25eca3 100644 --- a/tests/push/test_processor.py +++ b/tests/push/test_processor.py @@ -1,8 +1,11 @@ """Message processor tests.""" from queue import Queue -from splitio.push.processor import MessageProcessor -from splitio.sync.synchronizer import Synchronizer +import pytest + +from splitio.push.processor import MessageProcessor, MessageProcessorAsync +from splitio.sync.synchronizer import Synchronizer # , SynchronizerAsync to be added from splitio.push.parser import SplitChangeUpdate, SegmentChangeUpdate, SplitKillUpdate +from splitio.optional.loaders import asyncio class ProcessorTests(object): @@ -56,3 +59,59 @@ def test_segment_change(self, mocker): def test_todo(self): """Fix previous tests so that we validate WHICH queue the update is pushed into.""" assert NotImplementedError("DO THAT") + +class ProcessorAsyncTests(object): + """Message processor test cases.""" + + @pytest.mark.asyncio + async def test_split_change(self, mocker): + """Test split change is properly handled.""" + sync_mock = mocker.Mock(spec=Synchronizer) + self._update = None + async def put_mock(first, event): + self._update = event + + mocker.patch('splitio.push.processor.asyncio.Queue.put', new=put_mock) + processor = MessageProcessorAsync(sync_mock) + update = SplitChangeUpdate('sarasa', 123, 123) + await processor.handle(update) + assert update == self._update + + @pytest.mark.asyncio + async def test_split_kill(self, mocker): + """Test split kill is properly handled.""" + + self._killed_split = None + async def kill_mock(se, split_name, default_treatment, change_number): + self._killed_split = (split_name, default_treatment, change_number) + + mocker.patch('splitio.sync.synchronizer.SynchronizerAsync.kill_split', new=kill_mock) + sync_mock = SynchronizerAsync() + + self._update = None + async def put_mock(first, event): + self._update = event + + mocker.patch('splitio.push.processor.asyncio.Queue.put', new=put_mock) + processor = MessageProcessorAsync(sync_mock) + update = SplitKillUpdate('sarasa', 123, 456, 'some_split', 'off') + await processor.handle(update) + assert update == self._update + assert ('some_split', 'off', 456) == self._killed_split + + @pytest.mark.asyncio + async def test_segment_change(self, mocker): + """Test segment change is properly handled.""" + + sync_mock = SynchronizerAsync() + queue_mock = mocker.Mock(spec=asyncio.Queue) + + self._update = None + async def put_mock(first, event): + self._update = event + + mocker.patch('splitio.push.processor.asyncio.Queue.put', new=put_mock) + processor = MessageProcessorAsync(sync_mock) + update = SegmentChangeUpdate('sarasa', 123, 123, 'some_segment') + await processor.handle(update) + assert update == self._update diff --git a/tests/push/test_splitsse.py b/tests/push/test_splitsse.py index ebb8fa94..fbb12236 100644 --- a/tests/push/test_splitsse.py +++ b/tests/push/test_splitsse.py @@ -5,16 +5,14 @@ import pytest from splitio.models.token import Token - -from splitio.push.splitsse import SplitSSEClient -from splitio.push.sse import SSEEvent +from splitio.push.splitsse import SplitSSEClient, SplitSSEClientAsync +from splitio.push.sse import SSEEvent, SSE_EVENT_ERROR from tests.helpers.mockserver import SSEMockServer - from splitio.client.util import SdkMetadata +from splitio.optional.loaders import asyncio - -class SSEClientTests(object): +class SSESplitClientTests(object): """SSEClient test cases.""" def test_split_sse_success(self): @@ -124,3 +122,86 @@ def on_disconnect(): assert status['on_connect'] assert status['on_disconnect'] + + +class SSESplitClientAsyncTests(object): + """SSEClientAsync test cases.""" + + @pytest.mark.asyncio + async def test_split_sse_success(self): + """Test correct initialization. Client ends the connection.""" + request_queue = Queue() + server = SSEMockServer(request_queue) + server.start() + + client = SplitSSEClientAsync(SdkMetadata('1.0', 'some', '1.2.3.4'), + 'abcd', base_url='http://localhost:' + str(server.port())) + + token = Token(True, 'some', {'chan1': ['subscribe'], 'chan2': ['subscribe', 'channel-metadata:publishers']}, + 1, 2) + + server.publish({'id': '1'}) # send a non-error event early to unblock start + + events_loop = client.start(token) + first_event = await events_loop.__anext__() + assert first_event.event != SSE_EVENT_ERROR + + server.publish({'id': '1', 'data': 'a', 'retry': '1', 'event': 'message'}) + server.publish({'id': '2', 'data': 'a', 'retry': '1', 'event': 'message'}) + await asyncio.sleep(1) + + event2 = await events_loop.__anext__() + event3 = await events_loop.__anext__() + + await client.stop() + + request = request_queue.get(1) + assert request.path == '/event-stream?v=1.1&accessToken=some&channels=chan1,%5B?occupancy=metrics.publishers%5Dchan2' + assert request.headers['accept'] == 'text/event-stream' + assert request.headers['SplitSDKVersion'] == '1.0' + assert request.headers['SplitSDKMachineIP'] == '1.2.3.4' + assert request.headers['SplitSDKMachineName'] == 'some' + assert request.headers['SplitSDKClientKey'] == 'abcd' + + assert event2 == SSEEvent('1', 'message', '1', 'a') + assert event3 == SSEEvent('2', 'message', '1', 'a') + + server.publish(SSEMockServer.VIOLENT_REQUEST_END) + server.stop() + await asyncio.sleep(1) + + assert client.status == SplitSSEClient._Status.IDLE + + + @pytest.mark.asyncio + async def test_split_sse_error(self): + """Test correct initialization. Client ends the connection.""" + request_queue = Queue() + server = SSEMockServer(request_queue) + server.start() + + client = SplitSSEClientAsync(SdkMetadata('1.0', 'some', '1.2.3.4'), + 'abcd', base_url='http://localhost:' + str(server.port())) + + token = Token(True, 'some', {'chan1': ['subscribe'], 'chan2': ['subscribe', 'channel-metadata:publishers']}, + 1, 2) + + events_loop = client.start(token) + server.publish({'event': 'error'}) # send an error event early to unblock start + + await asyncio.sleep(1) + with pytest.raises( StopAsyncIteration): + await events_loop.__anext__() + + assert client.status == SplitSSEClient._Status.IDLE + + request = request_queue.get(1) + assert request.path == '/event-stream?v=1.1&accessToken=some&channels=chan1,%5B?occupancy=metrics.publishers%5Dchan2' + assert request.headers['accept'] == 'text/event-stream' + assert request.headers['SplitSDKVersion'] == '1.0' + assert request.headers['SplitSDKMachineIP'] == '1.2.3.4' + assert request.headers['SplitSDKMachineName'] == 'some' + assert request.headers['SplitSDKClientKey'] == 'abcd' + + server.publish(SSEMockServer.VIOLENT_REQUEST_END) + server.stop() diff --git a/tests/push/test_sse.py b/tests/push/test_sse.py index 62a272ec..4610d961 100644 --- a/tests/push/test_sse.py +++ b/tests/push/test_sse.py @@ -26,7 +26,7 @@ def callback(event): def runner(): """SSE client runner thread.""" assert client.start('http://127.0.0.1:' + str(server.port())) - client_task = threading.Thread(target=runner, daemon=True) + client_task = threading.Thread(target=runner) client_task.setName('client') client_task.start() with pytest.raises(RuntimeError): @@ -65,8 +65,8 @@ def callback(event): def runner(): """SSE client runner thread.""" - assert client.start('http://127.0.0.1:' + str(server.port())) - client_task = threading.Thread(target=runner, daemon=True) + assert not client.start('http://127.0.0.1:' + str(server.port())) + client_task = threading.Thread(target=runner) client_task.setName('client') client_task.start() @@ -102,7 +102,7 @@ def callback(event): def runner(): """SSE client runner thread.""" - assert client.start('http://127.0.0.1:' + str(server.port())) + assert not client.start('http://127.0.0.1:' + str(server.port())) client_task = threading.Thread(target=runner, daemon=True) client_task.setName('client') client_task.start() @@ -128,109 +128,100 @@ def runner(): class SSEClientAsyncTests(object): """SSEClient test cases.""" -# @pytest.mark.asyncio + @pytest.mark.asyncio async def test_sse_client_disconnects(self): """Test correct initialization. Client ends the connection.""" server = SSEMockServer() server.start() + client = SSEClientAsync() + sse_events_loop = client.start(f"http://127.0.0.1:{str(server.port())}?token=abc123$%^&(") + # sse_events_loop = client.start(f"http://127.0.0.1:{str(server.port())}") - events = [] - async def callback(event): - """Callback.""" - events.append(event) - - client = SSEClientAsync(callback) - - async def connect_split_sse_client(): - await client.start('http://127.0.0.1:' + str(server.port())) - - self._client_task = asyncio.gather(connect_split_sse_client()) server.publish({'id': '1'}) server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) server.publish({'id': '3', 'event': 'message', 'data': 'def'}) server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + await asyncio.sleep(1) + event1 = await sse_events_loop.__anext__() + event2 = await sse_events_loop.__anext__() + event3 = await sse_events_loop.__anext__() + event4 = await sse_events_loop.__anext__() await client.shutdown() - self._client_task.cancel() await asyncio.sleep(1) - assert events == [ - SSEEvent('1', None, None, None), - SSEEvent('2', 'message', None, 'abc'), - SSEEvent('3', 'message', None, 'def'), - SSEEvent('4', 'message', None, 'ghi') - ] - assert client._conn is None + assert event1 == SSEEvent('1', None, None, None) + assert event2 == SSEEvent('2', 'message', None, 'abc') + assert event3 == SSEEvent('3', 'message', None, 'def') + assert event4 == SSEEvent('4', 'message', None, 'ghi') + assert client._conn.closed + server.publish(server.GRACEFUL_REQUEST_END) server.stop() + @pytest.mark.asyncio async def test_sse_server_disconnects(self): """Test correct initialization. Server ends connection.""" server = SSEMockServer() server.start() + client = SSEClientAsync() + sse_events_loop = client.start('http://127.0.0.1:' + str(server.port())) - events = [] - async def callback(event): - """Callback.""" - events.append(event) - - client = SSEClientAsync(callback) - - async def start_client(): - await client.start('http://127.0.0.1:' + str(server.port())) - - asyncio.gather(start_client()) server.publish({'id': '1'}) server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) server.publish({'id': '3', 'event': 'message', 'data': 'def'}) server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) - server.publish(server.GRACEFUL_REQUEST_END) await asyncio.sleep(1) - server.stop() - await asyncio.sleep(1) + event1 = await sse_events_loop.__anext__() + event2 = await sse_events_loop.__anext__() + event3 = await sse_events_loop.__anext__() + event4 = await sse_events_loop.__anext__() - assert events == [ - SSEEvent('1', None, None, None), - SSEEvent('2', 'message', None, 'abc'), - SSEEvent('3', 'message', None, 'def'), - SSEEvent('4', 'message', None, 'ghi') - ] + server.publish(server.GRACEFUL_REQUEST_END) + try: + await sse_events_loop.__anext__() + except StopAsyncIteration: + pass + server.stop() + await asyncio.sleep(1) + assert event1 == SSEEvent('1', None, None, None) + assert event2 == SSEEvent('2', 'message', None, 'abc') + assert event3 == SSEEvent('3', 'message', None, 'def') + assert event4 == SSEEvent('4', 'message', None, 'ghi') assert client._conn is None + @pytest.mark.asyncio async def test_sse_server_disconnects_abruptly(self): """Test correct initialization. Server ends connection.""" server = SSEMockServer() server.start() - - events = [] - async def callback(event): - """Callback.""" - events.append(event) - - client = SSEClientAsync(callback) - - async def runner(): - """SSE client runner thread.""" - await client.start('http://127.0.0.1:' + str(server.port())) - - client_task = asyncio.gather(runner()) + client = SSEClientAsync() + sse_events_loop = client.start('http://127.0.0.1:' + str(server.port())) server.publish({'id': '1'}) server.publish({'id': '2', 'event': 'message', 'data': 'abc'}) server.publish({'id': '3', 'event': 'message', 'data': 'def'}) server.publish({'id': '4', 'event': 'message', 'data': 'ghi'}) + await asyncio.sleep(1) + event1 = await sse_events_loop.__anext__() + event2 = await sse_events_loop.__anext__() + event3 = await sse_events_loop.__anext__() + event4 = await sse_events_loop.__anext__() + server.publish(server.VIOLENT_REQUEST_END) - server.stop() - await asyncio.sleep(1) + try: + await sse_events_loop.__anext__() + except StopAsyncIteration: + pass - assert events == [ - SSEEvent('1', None, None, None), - SSEEvent('2', 'message', None, 'abc'), - SSEEvent('3', 'message', None, 'def'), - SSEEvent('4', 'message', None, 'ghi') - ] + server.stop() + await asyncio.sleep(1) + assert event1 == SSEEvent('1', None, None, None) + assert event2 == SSEEvent('2', 'message', None, 'abc') + assert event3 == SSEEvent('3', 'message', None, 'def') + assert event4 == SSEEvent('4', 'message', None, 'ghi') assert client._conn is None diff --git a/tests/storage/adapters/test_cache_trait.py b/tests/storage/adapters/test_cache_trait.py index 15f3b13a..2734d151 100644 --- a/tests/storage/adapters/test_cache_trait.py +++ b/tests/storage/adapters/test_cache_trait.py @@ -6,6 +6,7 @@ import pytest from splitio.storage.adapters import cache_trait +from splitio.optional.loaders import asyncio class CacheTraitTests(object): """Cache trait test cases.""" @@ -130,3 +131,11 @@ def test_decorate(self, mocker): assert cache_trait.decorate(key_func, 0, 10)(user_func) is user_func assert cache_trait.decorate(key_func, 10, 0)(user_func) is user_func assert cache_trait.decorate(key_func, 0, 0)(user_func) is user_func + + @pytest.mark.asyncio + async def test_async_add_and_get_key(self, mocker): + cache = cache_trait.LocalMemoryCache(None, None, 1, 1) + await cache.add_key('split', {'split_name': 'split'}) + assert await cache.get_key('split') == {'split_name': 'split'} + await asyncio.sleep(1) + assert await cache.get_key('split') == None diff --git a/tests/storage/adapters/test_redis_adapter.py b/tests/storage/adapters/test_redis_adapter.py index cb81dfb9..c04cab92 100644 --- a/tests/storage/adapters/test_redis_adapter.py +++ b/tests/storage/adapters/test_redis_adapter.py @@ -1,6 +1,7 @@ """Redis storage adapter test module.""" import pytest +from redis.asyncio.client import Redis as aioredis from splitio.storage.adapters import redis from redis import StrictRedis, Redis from redis.sentinel import Sentinel @@ -184,6 +185,321 @@ def test_sentinel_ssl_fails(self): }) +class RedisStorageAdapterAsyncTests(object): + """Redis storage adapter test cases.""" + + @pytest.mark.asyncio + async def test_forwarding(self, mocker): + """Test that all redis functions forward prefix appropriately.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.arg = None + async def keys(sel, args): + self.arg = args + return ['some_prefix.key1', 'some_prefix.key2'] + mocker.patch('redis.asyncio.client.Redis.keys', new=keys) + await adapter.keys('*') + assert self.arg == 'some_prefix.*' + + self.key = None + self.value = None + async def set(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.set', new=set) + await adapter.set('key1', 'value1') + assert self.key == 'some_prefix.key1' + assert self.value == 'value1' + + self.key = None + async def get(sel, key): + self.key = key + return 'value1' + mocker.patch('redis.asyncio.client.Redis.get', new=get) + await adapter.get('some_key') + assert self.key == 'some_prefix.some_key' + + self.key = None + self.value = None + self.exp = None + async def setex(sel, key, exp, value): + self.key = key + self.value = value + self.exp = exp + mocker.patch('redis.asyncio.client.Redis.setex', new=setex) + await adapter.setex('some_key', 123, 'some_value') + assert self.key == 'some_prefix.some_key' + assert self.exp == 123 + assert self.value == 'some_value' + + self.key = None + async def delete(sel, key): + self.key = key + mocker.patch('redis.asyncio.client.Redis.delete', new=delete) + await adapter.delete('some_key') + assert self.key == 'some_prefix.some_key' + + self.keys = None + async def mget(sel, keys): + self.keys = keys + return ['value1', 'value2', 'value3'] + mocker.patch('redis.asyncio.client.Redis.mget', new=mget) + await adapter.mget(['key1', 'key2', 'key3']) + assert self.keys == ['some_prefix.key1', 'some_prefix.key2', 'some_prefix.key3'] + + self.key = None + self.value = None + self.value2 = None + async def sadd(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Redis.sadd', new=sadd) + await adapter.sadd('s1', 'value1', 'value2') + assert self.key == 'some_prefix.s1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.value = None + self.value2 = None + async def srem(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Redis.srem', new=srem) + await adapter.srem('s1', 'value1', 'value2') + assert self.key == 'some_prefix.s1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.value = None + async def sismember(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.sismember', new=sismember) + await adapter.sismember('s1', 'value1') + assert self.key == 'some_prefix.s1' + assert self.value == 'value1' + + self.key = None + self.key2 = None + self.key3 = None + self.script = None + self.value = None + async def eval(sel, script, value, key, key2, key3): + self.key = key + self.key2 = key2 + self.key3 = key3 + self.script = script + self.value = value + mocker.patch('redis.asyncio.client.Redis.eval', new=eval) + await adapter.eval('script', 3, 'key1', 'key2', 'key3') + assert self.script == 'script' + assert self.value == 3 + assert self.key == 'some_prefix.key1' + assert self.key2 == 'some_prefix.key2' + assert self.key3 == 'some_prefix.key3' + + self.key = None + self.value = None + self.name = None + async def hset(sel, key, name, value): + self.key = key + self.value = value + self.name = name + mocker.patch('redis.asyncio.client.Redis.hset', new=hset) + await adapter.hset('key1', 'name', 'value') + assert self.key == 'some_prefix.key1' + assert self.name == 'name' + assert self.value == 'value' + + self.key = None + self.name = None + async def hget(sel, key, name): + self.key = key + self.name = name + mocker.patch('redis.asyncio.client.Redis.hget', new=hget) + await adapter.hget('key1', 'name') + assert self.key == 'some_prefix.key1' + assert self.name == 'name' + + self.key = None + self.value = None + async def incr(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.incr', new=incr) + await adapter.incr('key1') + assert self.key == 'some_prefix.key1' + assert self.value == 1 + + self.key = None + self.value = None + self.name = None + async def hincrby(sel, key, name, value): + self.key = key + self.value = value + self.name = name + mocker.patch('redis.asyncio.client.Redis.hincrby', new=hincrby) + await adapter.hincrby('key1', 'name1') + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 1 + + await adapter.hincrby('key1', 'name1', 5) + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 5 + + self.key = None + self.value = None + async def getset(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Redis.getset', new=getset) + await adapter.getset('key1', 'new_value') + assert self.key == 'some_prefix.key1' + assert self.value == 'new_value' + + self.key = None + self.value = None + self.value2 = None + async def rpush(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Redis.rpush', new=rpush) + await adapter.rpush('key1', 'value1', 'value2') + assert self.key == 'some_prefix.key1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.exp = None + async def expire(sel, key, exp): + self.key = key + self.exp = exp + mocker.patch('redis.asyncio.client.Redis.expire', new=expire) + await adapter.expire('key1', 10) + assert self.key == 'some_prefix.key1' + assert self.exp == 10 + + self.key = None + async def rpop(sel, key): + self.key = key + mocker.patch('redis.asyncio.client.Redis.rpop', new=rpop) + await adapter.rpop('key1') + assert self.key == 'some_prefix.key1' + + self.key = None + async def ttl(sel, key): + self.key = key + mocker.patch('redis.asyncio.client.Redis.ttl', new=ttl) + await adapter.ttl('key1') + assert self.key == 'some_prefix.key1' + + @pytest.mark.asyncio + async def test_adapter_building(self, mocker): + """Test buildin different types of client according to parameters received.""" + self.host = None + self.db = None + self.password = None + self.timeout = None + self.socket_connect_timeout = None + self.socket_keepalive = None + self.socket_keepalive_options = None + self.connection_pool = None + self.unix_socket_path = None + self.encoding = None + self.encoding_errors = None + self.errors = None + self.decode_responses = None + self.retry_on_timeout = None + self.ssl = None + self.ssl_keyfile = None + self.ssl_certfile = None + self.ssl_cert_reqs = None + self.ssl_ca_certs = None + self.max_connections = None + async def from_url(host, db, password, timeout, socket_connect_timeout, + socket_keepalive, socket_keepalive_options, connection_pool, + unix_socket_path, encoding, encoding_errors, errors, decode_responses, + retry_on_timeout, ssl, ssl_keyfile, ssl_certfile, ssl_cert_reqs, + ssl_ca_certs, max_connections): + self.host = host + self.db = db + self.password = password + self.timeout = timeout + self.socket_connect_timeout = socket_connect_timeout + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options + self.connection_pool = connection_pool + self.unix_socket_path = unix_socket_path + self.encoding = encoding + self.encoding_errors = encoding_errors + self.errors = errors + self.decode_responses = decode_responses + self.retry_on_timeout = retry_on_timeout + self.ssl = ssl + self.ssl_keyfile = ssl_keyfile + self.ssl_certfile = ssl_certfile + self.ssl_cert_reqs = ssl_cert_reqs + self.ssl_ca_certs = ssl_ca_certs + self.max_connections = max_connections + mocker.patch('redis.asyncio.client.Redis.from_url', new=from_url) + + config = { + 'redisHost': 'some_host', + 'redisPort': 1234, + 'redisDb': 0, + 'redisPassword': 'some_password', + 'redisSocketTimeout': 123, + 'redisSocketConnectTimeout': 456, + 'redisSocketKeepalive': 789, + 'redisSocketKeepaliveOptions': 10, + 'redisConnectionPool': 20, + 'redisUnixSocketPath': '/tmp/socket', + 'redisEncoding': 'utf-8', + 'redisEncodingErrors': 'strict', + 'redisErrors': 'abc', + 'redisDecodeResponses': True, + 'redisRetryOnTimeout': True, + 'redisSsl': True, + 'redisSslKeyfile': '/ssl.cert', + 'redisSslCertfile': '/ssl2.cert', + 'redisSslCertReqs': 'abc', + 'redisSslCaCerts': 'def', + 'redisMaxConnections': 5, + 'redisPrefix': 'some_prefix' + } + + await redis.build_async(config) + + assert self.host == 'redis://some_host:1234' + assert self.db == 0 + assert self.password == 'some_password' + assert self.timeout == 123 + assert self.socket_connect_timeout == 456 + assert self.socket_keepalive == 789 + assert self.socket_keepalive_options == 10 + assert self.connection_pool == 20 + assert self.unix_socket_path == '/tmp/socket' + assert self.encoding == 'utf-8' + assert self.encoding_errors == 'strict' + assert self.errors == 'abc' + assert self.decode_responses == True + assert self.retry_on_timeout == True + assert self.ssl == True + assert self.ssl_keyfile == '/ssl.cert' + assert self.ssl_certfile == '/ssl2.cert' + assert self.ssl_cert_reqs == 'abc' + assert self.ssl_ca_certs == 'def' + assert self.max_connections == 5 + + class RedisPipelineAdapterTests(object): """Redis pipelined adapter test cases.""" @@ -206,3 +522,55 @@ def test_forwarding(self, mocker): adapter.hincrby('key1', 'name1', 5) assert redis_mock_2.hincrby.mock_calls[1] == mocker.call('some_prefix.key1', 'name1', 5) + + +class RedisPipelineAdapterAsyncTests(object): + """Redis pipelined adapter test cases.""" + + @pytest.mark.asyncio + async def test_forwarding(self, mocker): + """Test that all redis functions forward prefix appropriately.""" + redis_mock = await aioredis.from_url("redis://localhost") + prefix_helper = redis.PrefixHelper('some_prefix') + adapter = redis.RedisPipelineAdapterAsync(redis_mock, prefix_helper) + + self.key = None + self.value = None + self.value2 = None + async def rpush(sel, key, value, value2): + self.key = key + self.value = value + self.value2 = value2 + mocker.patch('redis.asyncio.client.Pipeline.rpush', new=rpush) + await adapter.rpush('key1', 'value1', 'value2') + assert self.key == 'some_prefix.key1' + assert self.value == 'value1' + assert self.value2 == 'value2' + + self.key = None + self.value = None + async def incr(sel, key, value): + self.key = key + self.value = value + mocker.patch('redis.asyncio.client.Pipeline.incr', new=incr) + await adapter.incr('key1') + assert self.key == 'some_prefix.key1' + assert self.value == 1 + + self.key = None + self.value = None + self.name = None + async def hincrby(sel, key, name, value): + self.key = key + self.value = value + self.name = name + mocker.patch('redis.asyncio.client.Pipeline.hincrby', new=hincrby) + await adapter.hincrby('key1', 'name1') + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 1 + + await adapter.hincrby('key1', 'name1', 5) + assert self.key == 'some_prefix.key1' + assert self.name == 'name1' + assert self.value == 5 diff --git a/tests/storage/test_redis.py b/tests/storage/test_redis.py index 33fef5a6..5a5637d9 100644 --- a/tests/storage/test_redis.py +++ b/tests/storage/test_redis.py @@ -4,16 +4,20 @@ import json import time import unittest.mock as mock +import redis.asyncio as aioredis import pytest from splitio.client.util import get_metadata, SdkMetadata -from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, \ - RedisSegmentStorage, RedisSplitStorage, RedisTelemetryStorage +from splitio.optional.loaders import asyncio +from splitio.storage.redis import RedisEventsStorage, RedisImpressionsStorage, RedisImpressionsStorageAsync, \ + RedisSegmentStorage, RedisSplitStorage, RedisSplitStorageAsync, RedisTelemetryStorage, RedisTelemetryStorageAsync from splitio.storage.adapters.redis import RedisAdapter, RedisAdapterException, build +from redis.asyncio.client import Redis as aioredis +from splitio.storage.adapters import redis from splitio.models.segments import Segment from splitio.models.impressions import Impression from splitio.models.events import Event, EventWrapper -from splitio.models.telemetry import MethodExceptions, MethodLatencies, TelemetryConfig, MethodExceptionsAndLatencies +from splitio.models.telemetry import MethodExceptions, MethodLatencies, TelemetryConfig, MethodExceptionsAndLatencies, TelemetryConfigAsync class RedisSplitStorageTests(object): @@ -172,6 +176,259 @@ def test_is_valid_traffic_type_with_cache(self, mocker): time.sleep(1) assert storage.is_valid_traffic_type('any') is False +class RedisSplitStorageAsyncTests(object): + """Redis split storage test cases.""" + + @pytest.mark.asyncio + async def test_get_split(self, mocker): + """Test retrieving a split works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '{"name": "some_split"}' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + storage = RedisSplitStorageAsync(adapter) + await storage.get('some_split') + + assert self.name == 'SPLITIO.split.some_split' + assert self.redis_ret == '{"name": "some_split"}' + + # Test that a missing split returns None and doesn't call from_raw + from_raw.reset_mock() + self.name = None + async def get2(sel, name): + self.name = name + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + + result = await storage.get('some_split') + assert result is None + assert self.name == 'SPLITIO.split.some_split' + assert not from_raw.mock_calls + + @pytest.mark.asyncio + async def test_get_split_with_cache(self, mocker): + """Test retrieving a split works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '{"name": "some_split"}' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + storage = RedisSplitStorageAsync(adapter, True, 1) + await storage.get('some_split') + assert self.name == 'SPLITIO.split.some_split' + assert self.redis_ret == '{"name": "some_split"}' + + # hit the cache: + self.name = None + await storage.get('some_split') + self.name = None + await storage.get('some_split') + self.name = None + await storage.get('some_split') + assert self.name == None + + # Test that a missing split returns None and doesn't call from_raw + from_raw.reset_mock() + self.name = None + async def get2(sel, name): + self.name = name + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + + # Still cached + result = await storage.get('some_split') + assert result is not None + assert self.name == None + await asyncio.sleep(1) # wait for expiration + result = await storage.get('some_split') + assert self.name == 'SPLITIO.split.some_split' + assert result is None + + @pytest.mark.asyncio + async def test_get_splits_with_cache(self, mocker): + """Test retrieving a list of passed splits.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter, True, 1) + + self.redis_ret = None + self.name = None + async def mget(sel, name): + self.name = name + self.redis_ret = ['{"name": "split1"}', '{"name": "split2"}', None] + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.mget', new=mget) + + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + result = await storage.fetch_many(['split1', 'split2', 'split3']) + assert len(result) == 3 + + assert '{"name": "split1"}' in self.redis_ret + assert '{"name": "split2"}' in self.redis_ret + + assert result['split1'] is not None + assert result['split2'] is not None + assert 'split3' in result + + # fetch again + self.name = None + result = await storage.fetch_many(['split1', 'split2', 'split3']) + assert result['split1'] is not None + assert result['split2'] is not None + assert 'split3' in result + assert self.name == None + + # wait for expire + await asyncio.sleep(1) + self.name = None + result = await storage.fetch_many(['split1', 'split2', 'split3']) + assert self.name == ['SPLITIO.split.split1', 'SPLITIO.split.split2', 'SPLITIO.split.split3'] + + @pytest.mark.asyncio + async def test_get_changenumber(self, mocker): + """Test fetching changenumber.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + self.redis_ret = None + self.name = None + async def get(sel, name): + self.name = name + self.redis_ret = '-1' + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + + assert await storage.get_change_number() == -1 + assert self.name == 'SPLITIO.splits.till' + + @pytest.mark.asyncio + async def test_get_all_splits(self, mocker): + """Test fetching all splits.""" + from_raw = mocker.Mock() + mocker.patch('splitio.storage.redis.splits.from_raw', new=from_raw) + + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + self.redis_ret = None + self.name = None + async def mget(sel, name): + self.name = name + self.redis_ret = ['{"name": "split1"}', '{"name": "split2"}', '{"name": "split3"}'] + return self.redis_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.mget', new=mget) + + self.key = None + self.keys_ret = None + async def keys(sel, key): + self.key = key + self.keys_ret = [ + 'SPLITIO.split.split1', + 'SPLITIO.split.split2', + 'SPLITIO.split.split3' + ] + return self.keys_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.keys', new=keys) + + await storage.get_all_splits() + + assert self.key == 'SPLITIO.split.*' + assert self.keys_ret == ['SPLITIO.split.split1', 'SPLITIO.split.split2', 'SPLITIO.split.split3'] + assert len(from_raw.mock_calls) == 3 + assert mocker.call({'name': 'split1'}) in from_raw.mock_calls + assert mocker.call({'name': 'split2'}) in from_raw.mock_calls + assert mocker.call({'name': 'split3'}) in from_raw.mock_calls + + @pytest.mark.asyncio + async def test_get_split_names(self, mocker): + """Test getching split names.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + self.key = None + self.keys_ret = None + async def keys(sel, key): + self.key = key + self.keys_ret = [ + 'SPLITIO.split.split1', + 'SPLITIO.split.split2', + 'SPLITIO.split.split3' + ] + return self.keys_ret + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.keys', new=keys) + + assert await storage.get_split_names() == ['split1', 'split2', 'split3'] + + @pytest.mark.asyncio + async def test_is_valid_traffic_type(self, mocker): + """Test that traffic type validation works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter) + + async def get(sel, name): + return '1' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + assert await storage.is_valid_traffic_type('any') is True + + async def get2(sel, name): + return '0' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + assert await storage.is_valid_traffic_type('any') is False + + async def get3(sel, name): + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get3) + assert await storage.is_valid_traffic_type('any') is False + + @pytest.mark.asyncio + async def test_is_valid_traffic_type_with_cache(self, mocker): + """Test that traffic type validation works.""" + redis_mock = await aioredis.from_url("redis://localhost") + adapter = redis.RedisAdapterAsync(redis_mock, 'some_prefix') + storage = RedisSplitStorageAsync(adapter, True, 1) + + async def get(sel, name): + return '1' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get) + assert await storage.is_valid_traffic_type('any') is True + + async def get2(sel, name): + return '0' + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get2) + assert await storage.is_valid_traffic_type('any') is True + await asyncio.sleep(1) + assert await storage.is_valid_traffic_type('any') is False + + async def get3(sel, name): + return None + mocker.patch('splitio.storage.adapters.redis.RedisAdapterAsync.get', new=get3) + await asyncio.sleep(1) + assert await storage.is_valid_traffic_type('any') is False class RedisSegmentStorageTests(object): """Redis segment storage test cases.""" @@ -334,6 +591,167 @@ def test_add_impressions_to_pipe(self, mocker): storage.add_impressions_to_pipe(impressions, adapter) assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.impressions', *to_validate)] + def test_expire_key(self, mocker): + adapter = mocker.Mock(spec=RedisAdapter) + metadata = get_metadata({}) + storage = RedisImpressionsStorage(adapter, metadata) + + self.key = None + self.ttl = None + def expire(key, ttl): + self.key = key + self.ttl = ttl + adapter.expire = expire + + storage.expire_key(2, 2) + assert self.key == 'SPLITIO.impressions' + assert self.ttl == 3600 + + self.key = None + storage.expire_key(2, 1) + assert self.key == None + + +class RedisImpressionsStorageAsyncTests(object): # pylint: disable=too-few-public-methods + """Redis Impressions async storage test cases.""" + + def test_wrap_impressions(self, mocker): + """Test wrap impressions.""" + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654) + ] + + to_validate = [json.dumps({ + 'm': { # METADATA PORTION + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, + }, + 'i': { # IMPRESSION PORTION + 'k': impression.matching_key, + 'b': impression.bucketing_key, + 'f': impression.feature_name, + 't': impression.treatment, + 'r': impression.label, + 'c': impression.change_number, + 'm': impression.time, + } + }) for impression in impressions] + + assert storage._wrap_impressions(impressions) == to_validate + + @pytest.mark.asyncio + async def test_add_impressions(self, mocker): + """Test that adding impressions to storage works.""" + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654) + ] + self.key = None + self.imps = None + async def rpush(key, *imps): + self.key = key + self.imps = imps + + adapter.rpush = rpush + assert await storage.put(impressions) is True + + to_validate = [json.dumps({ + 'm': { # METADATA PORTION + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, + }, + 'i': { # IMPRESSION PORTION + 'k': impression.matching_key, + 'b': impression.bucketing_key, + 'f': impression.feature_name, + 't': impression.treatment, + 'r': impression.label, + 'c': impression.change_number, + 'm': impression.time, + } + }) for impression in impressions] + + assert self.key == 'SPLITIO.impressions' + assert self.imps == tuple(to_validate) + + # Assert that if an exception is thrown it's caught and False is returned + adapter.reset_mock() + + async def rpush2(key, *imps): + raise RedisAdapterException('something') + adapter.rpush = rpush2 + assert await storage.put(impressions) is False + + def test_add_impressions_to_pipe(self, mocker): + """Test that adding impressions to storage works.""" + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + impressions = [ + Impression('key1', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key2', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key3', 'feature2', 'on', 'some_label', 123456, 'buck1', 321654), + Impression('key4', 'feature1', 'on', 'some_label', 123456, 'buck1', 321654) + ] + + to_validate = [json.dumps({ + 'm': { # METADATA PORTION + 's': metadata.sdk_version, + 'n': metadata.instance_name, + 'i': metadata.instance_ip, + }, + 'i': { # IMPRESSION PORTION + 'k': impression.matching_key, + 'b': impression.bucketing_key, + 'f': impression.feature_name, + 't': impression.treatment, + 'r': impression.label, + 'c': impression.change_number, + 'm': impression.time, + } + }) for impression in impressions] + + storage.add_impressions_to_pipe(impressions, adapter) + assert adapter.rpush.mock_calls == [mocker.call('SPLITIO.impressions', *to_validate)] + + @pytest.mark.asyncio + async def test_expire_key(self, mocker): + adapter = mocker.Mock(spec=RedisAdapterAsync) + metadata = get_metadata({}) + storage = RedisImpressionsStorageAsync(adapter, metadata) + + self.key = None + self.ttl = None + async def expire(key, ttl): + self.key = key + self.ttl = ttl + adapter.expire = expire + + await storage.expire_key(2, 2) + assert self.key == 'SPLITIO.impressions' + assert self.ttl == 3600 + + self.key = None + await storage.expire_key(2, 1) + assert self.key == None + + class RedisEventsStorageTests(object): # pylint: disable=too-few-public-methods """Redis Impression storage test cases.""" @@ -485,3 +903,140 @@ def test_expire_keys(self, mocker): assert(not mocker.called) redis_telemetry.expire_keys('key', 12, 2, 2) assert(mocker.called) + + +class RedisTelemetryStorageAsyncTests(object): + """Redis Telemetry storage test cases.""" + + @pytest.mark.asyncio + async def test_init(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + assert(redis_telemetry._redis_client is not None) + assert(redis_telemetry._sdk_metadata is not None) + assert(isinstance(redis_telemetry._tel_config, TelemetryConfigAsync)) + assert(redis_telemetry._make_pipe is not None) + + @pytest.mark.asyncio + async def test_record_config(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + self.called = False + async def record_config(*args): + self.called = True + redis_telemetry._tel_config.record_config = record_config + + await redis_telemetry.record_config(mocker.Mock(), mocker.Mock()) + assert(self.called) + + @pytest.mark.asyncio + async def test_push_config_stats(self, mocker): + adapter = await aioredis.from_url("redis://localhost") + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, SdkMetadata('python-1.1.1', 'hostname', 'ip')) + self.key = None + self.hash = None + async def hset(key, hash, val): + self.key = key + self.hash = hash + + adapter.hset = hset + async def format_config_stats(stats, tags): + return "" + redis_telemetry._format_config_stats = format_config_stats + await redis_telemetry.push_config_stats() + assert self.key == 'SPLITIO.telemetry.init' + assert self.hash == 'python-1.1.1/hostname/ip' + + @pytest.mark.asyncio + async def test_format_config_stats(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + json_value = redis_telemetry._format_config_stats({'aF': 0, 'rF': 0, 'sT': None, 'oM': None}, []) + stats = await redis_telemetry._tel_config.get_stats() + assert(json_value == json.dumps({ + 'aF': stats['aF'], + 'rF': stats['rF'], + 'sT': stats['sT'], + 'oM': stats['oM'], + 't': await redis_telemetry.pop_config_tags() + })) + + @pytest.mark.asyncio + async def test_record_active_and_redundant_factories(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + active_factory_count = 1 + redundant_factory_count = 2 + await redis_telemetry.record_active_and_redundant_factories(1, 2) + assert (redis_telemetry._tel_config._active_factory_count == active_factory_count) + assert (redis_telemetry._tel_config._redundant_factory_count == redundant_factory_count) + + @pytest.mark.asyncio + async def test_add_latency_to_pipe(self, mocker): + adapter = build({}) + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, metadata) + pipe = adapter._decorated.pipeline() + + def _mocked_hincrby(*args, **kwargs): + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_LATENCIES_KEY) + assert(args[2][-11:] == 'treatment/0') + assert(args[3] == 1) + # should increment bucket 0 + with mock.patch('redis.client.Pipeline.hincrby', _mocked_hincrby): + redis_telemetry.add_latency_to_pipe(MethodExceptionsAndLatencies.TREATMENT, 0, pipe) + + def _mocked_hincrby2(*args, **kwargs): + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_LATENCIES_KEY) + assert(args[2][-11:] == 'treatment/3') + assert(args[3] == 1) + # should increment bucket 3 + with mock.patch('redis.client.Pipeline.hincrby', _mocked_hincrby2): + redis_telemetry.add_latency_to_pipe(MethodExceptionsAndLatencies.TREATMENT, 3, pipe) + + @pytest.mark.asyncio + async def test_record_exception(self, mocker): + self.called = False + def _mocked_hincrby(*args, **kwargs): + self.called = True + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_EXCEPTIONS_KEY) + assert(args[2] == 'python-1.1.1/hostname/ip/treatment') + assert(args[3] == 1) + + self.called2 = False + async def _mocked_execute(*args): + self.called2 = True + return [1] + + adapter = await aioredis.from_url("redis://localhost") + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, metadata) + with mock.patch('redis.asyncio.client.Pipeline.hincrby', _mocked_hincrby): + with mock.patch('redis.asyncio.client.Pipeline.execute', _mocked_execute): + await redis_telemetry.record_exception(MethodExceptionsAndLatencies.TREATMENT) + assert self.called + assert self.called2 + + @pytest.mark.asyncio + async def test_expire_latency_keys(self, mocker): + redis_telemetry = await RedisTelemetryStorageAsync.create(mocker.Mock(), mocker.Mock()) + def _mocked_method(*args, **kwargs): + assert(args[1] == RedisTelemetryStorageAsync._TELEMETRY_LATENCIES_KEY) + assert(args[2] == RedisTelemetryStorageAsync._TELEMETRY_KEY_DEFAULT_TTL) + assert(args[3] == 1) + assert(args[4] == 2) + + with mock.patch('splitio.storage.redis.RedisTelemetryStorage.expire_keys', _mocked_method): + await redis_telemetry.expire_latency_keys(1, 2) + + @pytest.mark.asyncio + async def test_expire_keys(self, mocker): + adapter = await aioredis.from_url("redis://localhost") + metadata = SdkMetadata('python-1.1.1', 'hostname', 'ip') + redis_telemetry = await RedisTelemetryStorageAsync.create(adapter, metadata) + self.called = False + async def expire(*args): + self.called = True + adapter.expire = expire + + await redis_telemetry.expire_keys('key', 12, 1, 2) + assert(not self.called) + + await redis_telemetry.expire_keys('key', 12, 2, 2) + assert(self.called)