diff --git a/discovery-provider/src/tasks/index.py b/discovery-provider/src/tasks/index.py index 9b07ab1c6a4..5a055862587 100644 --- a/discovery-provider/src/tasks/index.py +++ b/discovery-provider/src/tasks/index.py @@ -3,10 +3,11 @@ from src.models import Block, User, Track, Repost, Follow, Playlist, Save from src.tasks.celery_app import celery from src.tasks.tracks import track_state_update -from src.tasks.users import user_state_update, get_ipfs_info_from_cnode_endpoint # pylint: disable=E0611,E0001 +from src.tasks.users import user_state_update # pylint: disable=E0611,E0001 from src.tasks.social_features import social_feature_state_update from src.tasks.playlists import playlist_state_update from src.tasks.user_library import user_library_state_update +from src.utils.helpers import get_ipfs_info_from_cnode_endpoint logger = logging.getLogger(__name__) diff --git a/discovery-provider/src/tasks/tracks.py b/discovery-provider/src/tasks/tracks.py index bb0e9fe7fdb..a525fd9eafa 100644 --- a/discovery-provider/src/tasks/tracks.py +++ b/discovery-provider/src/tasks/tracks.py @@ -2,8 +2,8 @@ from datetime import datetime from sqlalchemy.orm.session import make_transient from src import contract_addresses -from src.utils import multihash -from src.models import Track, BlacklistedIPLD +from src.utils import multihash, helpers +from src.models import Track, User, BlacklistedIPLD from src.tasks.metadata import track_metadata_format logger = logging.getLogger(__name__) @@ -143,6 +143,10 @@ def parse_track_event( return track_record track_record.owner_id = event_args._trackOwnerId + + # Reconnect to creator nodes for this user + refresh_track_owner_ipfs_conn(track_record.owner_id, session, update_task) + track_record.is_delete = False track_metadata = update_task.ipfs_client.get_metadata( track_metadata_multihash, @@ -185,6 +189,10 @@ def parse_track_event( track_record.owner_id = event_args._trackOwnerId track_record.is_delete = False + + # Reconnect to creator nodes for this user + refresh_track_owner_ipfs_conn(track_record.owner_id, session, update_task) + track_metadata = update_task.ipfs_client.get_metadata( upd_track_metadata_multihash, track_metadata_format @@ -243,3 +251,19 @@ def populate_track_record_metadata(track_record, track_metadata): track_record.iswc = track_metadata["iswc"] track_record.track_segments = track_metadata["track_segments"] return track_record + +def refresh_track_owner_ipfs_conn(owner_id, session, update_task): + owner_record = ( + session.query(User.creator_node_endpoint) + .filter( + User.is_current == True, + User.is_ready == True, + User.user_id == owner_id) + .all() + ) + if len(owner_record) >= 1: + parsed_endpoint_list = owner_record[0][0] + helpers.update_ipfs_peers_from_user_endpoint( + update_task, + parsed_endpoint_list + ) diff --git a/discovery-provider/src/tasks/users.py b/discovery-provider/src/tasks/users.py index e8b81db62e6..c731944c8b1 100644 --- a/discovery-provider/src/tasks/users.py +++ b/discovery-provider/src/tasks/users.py @@ -1,8 +1,6 @@ import logging -from urllib.parse import urljoin from datetime import datetime from sqlalchemy.orm.session import make_transient -import requests from src import contract_addresses from src.utils import helpers from src.models import User, BlacklistedIPLD @@ -246,7 +244,7 @@ def get_metadata_overrides_from_ipfs(session, update_task, user_record): return None # Manually peer with user creator nodes - update_ipfs_peers_from_user_endpoint( + helpers.update_ipfs_peers_from_user_endpoint( update_task, user_record.creator_node_endpoint ) @@ -257,31 +255,3 @@ def get_metadata_overrides_from_ipfs(session, update_task, user_record): ) return user_metadata - - -def get_ipfs_info_from_cnode_endpoint(url): - id_url = urljoin(url, 'ipfs_peer_info') - resp = requests.get(id_url, timeout=5) - json_resp = resp.json() - if 'addresses' in json_resp and isinstance(json_resp['addresses'], list): - for multiaddr in json_resp['addresses']: - if ('127.0.0.1' not in multiaddr) and ('ip6' not in multiaddr): - return multiaddr - raise Exception('Failed to find valid multiaddr') - - -def update_ipfs_peers_from_user_endpoint(update_task, cnode_url_list): - if cnode_url_list is None: - return - redis = update_task.redis - cnode_entries = cnode_url_list.split(',') - interval = int(update_task.shared_config["discprov"]["peer_refresh_interval"]) - for cnode_url in cnode_entries: - if cnode_url == '': - continue - try: - multiaddr = get_ipfs_info_from_cnode_endpoint(cnode_url) - update_task.ipfs_client.connect_peer(multiaddr) - redis.set(cnode_url, multiaddr, interval) - except Exception as e: # pylint: disable=broad-except - logger.warning(f"Error retrieving info for {cnode_url}, {e}") diff --git a/discovery-provider/src/utils/helpers.py b/discovery-provider/src/utils/helpers.py index 80f8c471823..135239dd28a 100644 --- a/discovery-provider/src/utils/helpers.py +++ b/discovery-provider/src/utils/helpers.py @@ -1,10 +1,11 @@ +import logging import os import json -import logging import contextlib +from urllib.parse import urljoin +import requests from . import multihash - @contextlib.contextmanager def cd(path): """Context manager that changes to directory `path` and return to CWD @@ -120,3 +121,30 @@ def get_discovery_provider_version(): with open(versionFilePath) as f: data = json.load(f) return data + +def get_ipfs_info_from_cnode_endpoint(url): + id_url = urljoin(url, 'ipfs_peer_info') + resp = requests.get(id_url, timeout=5) + json_resp = resp.json() + if 'addresses' in json_resp and isinstance(json_resp['addresses'], list): + for multiaddr in json_resp['addresses']: + if ('127.0.0.1' not in multiaddr) and ('ip6' not in multiaddr): + return multiaddr + raise Exception('Failed to find valid multiaddr') + +def update_ipfs_peers_from_user_endpoint(update_task, cnode_url_list): + logger = logging.getLogger(__name__) + if cnode_url_list is None: + return + redis = update_task.redis + cnode_entries = cnode_url_list.split(',') + interval = int(update_task.shared_config["discprov"]["peer_refresh_interval"]) + for cnode_url in cnode_entries: + if cnode_url == '': + continue + try: + multiaddr = get_ipfs_info_from_cnode_endpoint(cnode_url) + update_task.ipfs_client.connect_peer(multiaddr) + redis.set(cnode_url, multiaddr, interval) + except Exception as e: # pylint: disable=broad-except + logger.warning(f"Error connecting to {cnode_url}, {e}") diff --git a/discovery-provider/src/utils/ipfs_lib.py b/discovery-provider/src/utils/ipfs_lib.py index db6f2f9b915..2fc270e7e9a 100644 --- a/discovery-provider/src/utils/ipfs_lib.py +++ b/discovery-provider/src/utils/ipfs_lib.py @@ -2,9 +2,9 @@ import json import time from urllib.parse import urlparse -import ipfshttpclient import requests from requests.exceptions import ReadTimeout +import ipfshttpclient logger = logging.getLogger(__name__)