diff --git a/discovery-provider/src/queries/query_helpers.py b/discovery-provider/src/queries/query_helpers.py index a02946db5ef..1f104e971fa 100644 --- a/discovery-provider/src/queries/query_helpers.py +++ b/discovery-provider/src/queries/query_helpers.py @@ -24,6 +24,28 @@ defaultOffset = 0 minOffset = 0 +# Used when generating genre list to special case Electronic tunes +electronic_sub_genres = [ + 'Techno', + 'Trap', + 'House', + 'Tech House', + 'Deep House', + 'Disco', + 'Electro', + 'Jungle', + 'Progressive House', + 'Hardstyle', + 'Glitch Hop', + 'Trance', + 'Future Bass', + 'Future House', + 'Tropical House', + 'Downtempo', + 'Drum & Bass', + 'Dubstep', + 'Jersey Club', +] ######## HELPERS ######## @@ -591,3 +613,11 @@ def paginate_query(query_obj, apply_offset=True): (limit, offset) = get_pagination_vars() query_obj = query_obj.limit(limit) return query_obj.offset(offset) if apply_offset else query_obj + +def get_genre_list(genre): + genre_list = [] + genre_list.append(genre) + if genre == 'Electronic': + genre_list = genre_list + electronic_sub_genres + return genre_list + diff --git a/discovery-provider/src/queries/trending.py b/discovery-provider/src/queries/trending.py index 9861a74f2a0..c312260388a 100644 --- a/discovery-provider/src/queries/trending.py +++ b/discovery-provider/src/queries/trending.py @@ -4,14 +4,14 @@ from sqlalchemy import func from flask import Blueprint, request -from urllib.parse import urljoin +from urllib.parse import urljoin, unquote from src import api_helpers from src.models import User, Track, RepostType, Follow, SaveType from src.utils.db_session import get_db from src.utils.config import shared_config from src.queries import response_name_constants -from src.queries.query_helpers import get_repost_counts, get_pagination_vars, get_save_counts +from src.queries.query_helpers import get_repost_counts, get_pagination_vars, get_save_counts, get_genre_list logger = logging.getLogger(__name__) bp = Blueprint("trending", __name__) @@ -23,20 +23,41 @@ def trending(time): identity_url = shared_config['discprov']['identity_service_url'] identity_trending_endpoint = urljoin(identity_url, f"/tracks/trending/{time}") + db = get_db() (limit, offset) = get_pagination_vars() - queryparams = {} - queryparams["limit"] = limit - queryparams["offset"] = offset + post_body = {} + post_body["limit"] = limit + post_body["offset"] = offset + + # Retrieve genre and query all tracks if required + genre = request.args.get("genre", default=None, type=str) + if genre is not None: + # Parse encoded characters, such as Hip-Hop%252FRap -> Hip-Hop/Rap + genre = unquote(genre) + with db.scoped_session() as session: + genre_list = get_genre_list(genre) + genre_track_ids = ( + session.query(Track.track_id) + .filter( + Track.genre.in_(genre_list), + Track.is_current == True, + Track.is_delete == False + ) + .all() + ) + genre_specific_track_ids = [record[0] for record in genre_track_ids] + post_body["track_ids"] = genre_specific_track_ids + # Query trending information from identity service resp = None try: - resp = requests.get(identity_trending_endpoint, params=queryparams) + resp = requests.post(identity_trending_endpoint, json=post_body) except Exception as e: logger.error( - f'Error retrieving trending info - {identity_trending_endpoint}, {queryparams}' + f'Error retrieving trending info - {identity_trending_endpoint}, {post_body}' ) - raise e + return api_helpers.error_response(e, 500) json_resp = resp.json() if "error" in json_resp: @@ -50,7 +71,6 @@ def trending(time): track_ids = [track[response_name_constants.track_id] for track in listen_counts] - db = get_db() with db.scoped_session() as session: # Filter tracks to not-deleted ones so trending order is preserved not_deleted_track_ids = ( @@ -113,7 +133,7 @@ def trending(time): if save_type == SaveType.track } - trending = [] + trending_tracks = [] for track_entry in listen_counts: # Skip over deleted tracks if (track_entry[response_name_constants.track_id] not in not_deleted_track_ids): @@ -125,7 +145,7 @@ def trending(time): track_repost_counts[track_entry[response_name_constants.track_id]] else: track_entry[response_name_constants.repost_count] = 0 - + # Populate save counts if track_entry[response_name_constants.track_id] in track_save_counts: track_entry[response_name_constants.save_count] = \ @@ -141,9 +161,8 @@ def trending(time): track_entry[response_name_constants.track_owner_id] = owner_id track_entry[response_name_constants.track_owner_follower_count] = owner_follow_count - trending.append(track_entry) + trending_tracks.append(track_entry) final_resp = {} - final_resp['listen_counts'] = trending + final_resp['listen_counts'] = trending_tracks return api_helpers.success_response(final_resp) - diff --git a/identity-service/src/routes/trackListens.js b/identity-service/src/routes/trackListens.js index 4f745d6293c..139d3a75c88 100644 --- a/identity-service/src/routes/trackListens.js +++ b/identity-service/src/routes/trackListens.js @@ -10,10 +10,25 @@ async function getListenHour () { return listenDate } -let oneDayInMs = (24 * 60 * 60 * 1000) -let oneWeekInMs = oneDayInMs * 7 -let oneMonthInMs = oneDayInMs * 30 -let oneYearInMs = oneMonthInMs * 12 +const oneDayInMs = (24 * 60 * 60 * 1000) +const oneWeekInMs = oneDayInMs * 7 +const oneMonthInMs = oneDayInMs * 30 +const oneYearInMs = oneMonthInMs * 12 + +// Limit / offset related constants +const defaultLimit = 100 +const minLimit = 1 +const maxLimit = 500 +const defaultOffset = 0 +const minOffset = 0 + +const getPaginationVars = (limit, offset) => { + if (!limit) limit = defaultLimit + if (!offset) offset = defaultOffset + let boundedLimit = Math.min(Math.max(limit, minLimit), maxLimit) + let boundedOffset = Math.max(offset, minOffset) + return { limit: boundedLimit, offset: boundedOffset } +} const parseTimeframe = (inputTime) => { switch (inputTime) { @@ -116,6 +131,87 @@ const getTrackListens = async ( return output } +const getTrendingTracks = async ( + idList, + timeFrame, + limit, + offset) => { + if (idList !== undefined && !Array.isArray(idList)) { + return errorResponseBadRequest('Invalid id list provided. Please provide an array of track IDs') + } + + let dbQuery = { + attributes: ['trackId', [models.Sequelize.fn('sum', models.Sequelize.col('listens')), 'listens']], + group: ['trackId'], + order: [[models.Sequelize.col('listens'), 'DESC'], [models.Sequelize.col('trackId'), 'DESC']], + where: {} + } + + // If id list present, add filter + if (idList && idList.length > 0) { + dbQuery.where.trackId = { [models.Sequelize.Op.in]: idList } + } + + let currentHour = await getListenHour() + switch (timeFrame) { + case 'day': + let oneDayBefore = new Date(currentHour.getTime() - oneDayInMs) + dbQuery.where.hour = { [models.Sequelize.Op.gte]: oneDayBefore } + break + case 'week': + let oneWeekBefore = new Date(currentHour.getTime() - oneWeekInMs) + dbQuery.where.hour = { [models.Sequelize.Op.gte]: oneWeekBefore } + break + case 'month': + let oneMonthBefore = new Date(currentHour.getTime() - oneMonthInMs) + dbQuery.where.hour = { [models.Sequelize.Op.gte]: oneMonthBefore } + break + case 'year': + let oneYearBefore = new Date(currentHour.getTime() - oneYearInMs) + dbQuery.where.hour = { [models.Sequelize.Op.gte]: oneYearBefore } + break + case undefined: + break + default: + return errorResponseBadRequest('Invalid time parameter provided, use day/week/month/year or no parameter') + } + if (limit) { + dbQuery.limit = limit + } + + if (offset) { + dbQuery.offset = offset + } + + let listenCounts = await models.TrackListenCount.findAll(dbQuery) + let parsedListenCounts = [] + let seenTrackIds = [] + listenCounts.forEach((elem) => { + parsedListenCounts.push({ trackId: elem.trackId, listens: parseInt(elem.listens) }) + seenTrackIds.push(elem.trackId) + }) + + const seenIdSet = new Set(seenTrackIds) + if (idList && seenIdSet.size < idList.length) { + // For any tracks in the required id list that were not listened to in the last + // Populate empty listen counts + for (var i = 0; i < idList.length; i++) { + const id = parseInt(idList[i]) + // Add tracks only if not already present in parsedListenCounts + if (!seenIdSet.has(id)) { + parsedListenCounts.push({ trackId: id, listens: 0 }) + } + + // Exit if desired response limit has been met + if (limit && parsedListenCounts.length >= limit) { + break + } + } + } + + return parsedListenCounts +} + module.exports = function (app) { app.post('/tracks/:id/listen', handleResponse(async (req, res) => { const trackId = parseInt(req.params.id) @@ -162,11 +258,10 @@ module.exports = function (app) { app.post('/tracks/listens/:timeframe*?', handleResponse(async (req, res, next) => { let body = req.body let idList = body.track_ids - let limit = body.limit - let offset = body.offset let startTime = body.startTime let endTime = body.endTime let time = parseTimeframe(req.params.timeframe) + let { limit, offset } = getPaginationVars(body.limit, body.offset) let output = await getTrackListens( idList, time, @@ -179,12 +274,11 @@ module.exports = function (app) { })) app.get('/tracks/listens/:timeframe*?', handleResponse(async (req, res) => { - let limit = req.query.limit - let offset = req.query.offset let idList = req.query.id let startTime = req.query.start let endTime = req.query.end let time = parseTimeframe(req.params.timeframe) + let { limit, offset } = getPaginationVars(req.query.limit, req.query.offset) let output = await getTrackListens( idList, time, @@ -204,84 +298,38 @@ module.exports = function (app) { * -