From d80ec2833a8c8363c451173d1661d4ae63513cc3 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:56:18 -0400 Subject: [PATCH 1/6] Fix managed users endpoints to be based on user_id alone --- .../queries/test_get_managed_users.py | 27 ++++------------ .../discovery-provider/src/api/v1/users.py | 19 ++++-------- .../src/queries/get_managed_users.py | 31 ++++++++++--------- .../src/utils/auth_middleware.py | 4 --- 4 files changed, 29 insertions(+), 52 deletions(-) diff --git a/packages/discovery-provider/integration_tests/queries/test_get_managed_users.py b/packages/discovery-provider/integration_tests/queries/test_get_managed_users.py index 23c063632d7..9678a0f430b 100644 --- a/packages/discovery-provider/integration_tests/queries/test_get_managed_users.py +++ b/packages/discovery-provider/integration_tests/queries/test_get_managed_users.py @@ -98,9 +98,7 @@ def test_get_managed_users_default(app): db = get_db() populate_mock_db(db, {"users": test_users, "grants": test_managed_user_grants}) - managed_users = get_managed_users_with_grants( - {"manager_wallet_address": "0x10", "current_user_id": 10} - ) + managed_users = get_managed_users_with_grants({"user_id": 10}) # return all non-revoked records by default assert len(managed_users) == 3, "Expected exactly 3 records" @@ -116,8 +114,7 @@ def test_get_managed_users_no_filters(app): managed_users = get_managed_users_with_grants( { - "manager_wallet_address": "0x10", - "current_user_id": 10, + "user_id": 10, "is_approved": None, "is_revoked": None, } @@ -146,9 +143,7 @@ def test_get_managed_users_grants_without_users(app): ) populate_mock_db(db, entities) - managed_users = get_managed_users_with_grants( - {"manager_wallet_address": "0x10", "current_user_id": 10} - ) + managed_users = get_managed_users_with_grants({"user_id": 10}) # return all non-revoked records by default assert len(managed_users) == 3, "Expected exactly 3 records" @@ -163,20 +158,10 @@ def test_get_managed_users_invalid_parameters(app): populate_mock_db(db, {"users": test_users, "grants": test_managed_user_grants}) try: - get_managed_users_with_grants( - {"manager_wallet_address": None, "current_user_id": 10} - ) - assert False, "Should have thrown an error for missing wallet address" + get_managed_users_with_grants({"user_id": None}) + assert False, "Should have thrown an error for missing user_id" except ValueError as e: - assert str(e) == "manager_wallet_address is required" - - try: - get_managed_users_with_grants( - {"manager_wallet_address": "0x10", "current_user_id": None} - ) - assert False, "Should have thrown an error for missing current user id" - except ValueError as e: - assert str(e) == "current_user_id is required" + assert str(e) == "user_id is required" # ### get_user_managers ### # diff --git a/packages/discovery-provider/src/api/v1/users.py b/packages/discovery-provider/src/api/v1/users.py index eeea2f3ac92..ae049036551 100644 --- a/packages/discovery-provider/src/api/v1/users.py +++ b/packages/discovery-provider/src/api/v1/users.py @@ -2076,17 +2076,12 @@ class ManagedUsers(Resource): 500: "Server error", }, ) - @auth_middleware(include_wallet=True, require_auth=True) + @auth_middleware(require_auth=True) @full_ns.marshal_with(managed_users_response) - def get(self, id, authed_user_id, authed_user_wallet): + def get(self, id, authed_user_id): user_id = decode_with_abort(id, full_ns) - check_authorized(user_id, authed_user_id) - - args = GetManagedUsersArgs( - manager_wallet_address=authed_user_wallet, current_user_id=user_id - ) - users = get_managed_users_with_grants(args) + users = get_managed_users_with_grants(GetManagedUsersArgs(user_id=user_id)) users = list(map(format_managed_user, users)) return success_response(users) @@ -2112,15 +2107,13 @@ class Managers(Resource): 500: "Server error", }, ) - @auth_middleware(include_wallet=True, require_auth=True) + @auth_middleware(require_auth=True) @full_ns.marshal_with(managers_response) - def get(self, id, authed_user_id, authed_user_wallet): + def get(self, id, authed_user_id): user_id = decode_with_abort(id, full_ns) check_authorized(user_id, authed_user_id) - args = GetUserManagersArgs( - manager_wallet_address=authed_user_wallet, user_id=user_id - ) + args = GetUserManagersArgs(user_id=user_id) managers = get_user_managers_with_grants(args) managers = list(map(format_user_manager, managers)) diff --git a/packages/discovery-provider/src/queries/get_managed_users.py b/packages/discovery-provider/src/queries/get_managed_users.py index b7d170e63ac..5bba0f70189 100644 --- a/packages/discovery-provider/src/queries/get_managed_users.py +++ b/packages/discovery-provider/src/queries/get_managed_users.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional, TypedDict from src.models.grants.grant import Grant +from src.models.users.user import User from src.queries.get_unpopulated_users import ( get_unpopulated_users, get_unpopulated_users_by_wallet, @@ -14,8 +15,7 @@ class GetManagedUsersArgs(TypedDict): - manager_wallet_address: str - current_user_id: int + user_id: int is_approved: Optional[bool] is_revoked: Optional[bool] @@ -115,7 +115,7 @@ def get_managed_users_with_grants(args: GetManagedUsersArgs) -> List[Dict]: Returns users managed by the given wallet address Args: - manager_wallet_address: str wallet address of the manager + user_id: Id of the manager is_approved: Optional[bool] If set, filters by approval status is_revoked: Optional[bool] If set, filters by revocation status, defaults to False @@ -124,17 +124,17 @@ def get_managed_users_with_grants(args: GetManagedUsersArgs) -> List[Dict]: """ is_approved = args.get("is_approved", None) is_revoked = args.get("is_revoked", False) - current_user_id = args.get("current_user_id") - grantee_address = args.get("manager_wallet_address") - if grantee_address is None: - raise ValueError("manager_wallet_address is required") - if current_user_id is None: - raise ValueError("current_user_id is required") + user_id = args.get("user_id") + if user_id is None: + raise ValueError("user_id is required") db = db_session.get_db_read_replica() with db.scoped_session() as session: - query = session.query(Grant).filter( - Grant.grantee_address == grantee_address, Grant.is_current == True + query = ( + session.query(User.user_id, Grant) + .join(Grant, User.wallet == Grant.grantee_address) + .filter(User.user_id == user_id) + .filter(Grant.is_current == True) ) if is_approved is not None: @@ -142,13 +142,16 @@ def get_managed_users_with_grants(args: GetManagedUsersArgs) -> List[Dict]: if is_revoked is not None: query = query.filter(Grant.is_revoked == is_revoked) - grants = query.all() - if len(grants) == 0: + results = query.all() + if len(results) == 0: return [] + grants = [grant for [_, grant] in results] user_ids = [grant.user_id for grant in grants] users = get_unpopulated_users(session, user_ids) - users = populate_user_metadata(session, user_ids, users, current_user_id) + users = populate_user_metadata( + session, user_ids, users, current_user_id=user_id + ) grants = query_result_to_list(grants) diff --git a/packages/discovery-provider/src/utils/auth_middleware.py b/packages/discovery-provider/src/utils/auth_middleware.py index b6f8e8bf6fc..7550c7a7dc6 100644 --- a/packages/discovery-provider/src/utils/auth_middleware.py +++ b/packages/discovery-provider/src/utils/auth_middleware.py @@ -17,8 +17,6 @@ def auth_middleware( parser: reqparse.RequestParser = None, - # Include the wallet in the kwargs for the wrapped function - include_wallet: bool = False, # If True, user must be authenticated to access this route, will abort with 401 if no user is found in headers. require_auth: bool = False, ): @@ -109,8 +107,6 @@ def wrapper(*args, **kwargs): abort(401, "You must be logged in to make this request.") kwargs["authed_user_id"] = authed_user_id - if include_wallet: - kwargs["authed_user_wallet"] = wallet_lower return func(*args, **kwargs) From a7215c1c6f4b313dbbaafea86dee6f51f99d66e0 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:26:47 -0400 Subject: [PATCH 2/6] fix track history endpoint to allow managers --- .../test_get_user_listening_history.py | 33 ------------------- .../discovery-provider/src/api/v1/users.py | 26 ++++++++++----- .../src/queries/get_user_listening_history.py | 15 +++------ .../store/pages/history/lineups/sagas.ts | 17 ++++++---- 4 files changed, 32 insertions(+), 59 deletions(-) diff --git a/packages/discovery-provider/integration_tests/queries/test_get_user_listening_history.py b/packages/discovery-provider/integration_tests/queries/test_get_user_listening_history.py index c9568b3daee..a10295a1386 100644 --- a/packages/discovery-provider/integration_tests/queries/test_get_user_listening_history.py +++ b/packages/discovery-provider/integration_tests/queries/test_get_user_listening_history.py @@ -66,7 +66,6 @@ def test_get_user_listening_history_multiple_plays(app): session, GetUserListeningHistoryArgs( user_id=1, - current_user_id=1, limit=10, offset=0, query=None, @@ -116,7 +115,6 @@ def test_get_user_listening_history_no_plays(app): session, GetUserListeningHistoryArgs( user_id=3, - current_user_id=3, limit=10, offset=0, query=None, @@ -142,7 +140,6 @@ def test_get_user_listening_history_single_play(app): session, GetUserListeningHistoryArgs( user_id=2, - current_user_id=2, limit=10, offset=0, query=None, @@ -176,7 +173,6 @@ def test_get_user_listening_history_pagination(app): session, GetUserListeningHistoryArgs( user_id=1, - current_user_id=1, limit=1, offset=1, query=None, @@ -196,32 +192,6 @@ def test_get_user_listening_history_pagination(app): ) -def test_get_user_listening_history_mismatch_user_id(app): - """Tests a listening history with mismatching user ids""" - with app.app_context(): - db = get_db() - - populate_mock_db(db, test_entities) - - with db.scoped_session() as session: - _index_user_listening_history(session) - - track_history = _get_user_listening_history( - session, - GetUserListeningHistoryArgs( - user_id=1, - current_user_id=2, - limit=10, - offset=0, - query=None, - sort_method=None, - sort_direction=None, - ), - ) - - assert len(track_history) == 0 - - def test_get_user_listening_history_with_query(app): """Tests listening history from user with a query""" with app.app_context(): @@ -236,7 +206,6 @@ def test_get_user_listening_history_with_query(app): session, GetUserListeningHistoryArgs( user_id=1, - current_user_id=1, limit=10, offset=0, query="track 2", @@ -268,7 +237,6 @@ def test_get_user_listening_history_custom_sort(app): session, GetUserListeningHistoryArgs( user_id=1, - current_user_id=1, limit=10, offset=0, query=None, @@ -318,7 +286,6 @@ def test_get_user_listening_history_sort_by_most_listens(app): session, GetUserListeningHistoryArgs( user_id=1, - current_user_id=1, limit=10, offset=0, query=None, diff --git a/packages/discovery-provider/src/api/v1/users.py b/packages/discovery-provider/src/api/v1/users.py index ae049036551..9cebd5f2670 100644 --- a/packages/discovery-provider/src/api/v1/users.py +++ b/packages/discovery-provider/src/api/v1/users.py @@ -1071,9 +1071,8 @@ class TrackHistoryFull(Resource): def _get(self, id, authed_user_id): args = track_history_parser.parse_args() decoded_id = decode_with_abort(id, ns) - current_user_id = get_current_user_id(args) - if not current_user_id and decoded_id == authed_user_id: - current_user_id = authed_user_id + check_authorized(decoded_id, authed_user_id) + offset = format_offset(args) limit = format_limit(args) query = format_query(args) @@ -1081,7 +1080,6 @@ def _get(self, id, authed_user_id): sort_direction = format_sort_direction(args) get_tracks_args = GetUserListeningHistoryArgs( user_id=decoded_id, - current_user_id=current_user_id, limit=limit, offset=offset, query=query, @@ -1096,10 +1094,16 @@ def _get(self, id, authed_user_id): id="""Get User's Track History""", description="""Get the tracks the user recently listened to.""", params={"id": "A User ID"}, - responses={200: "Success", 400: "Bad request", 500: "Server error"}, + responses={ + 200: "Success", + 400: "Bad request", + 401: "Unauthorized", + 403: "Forbidden", + 500: "Server error", + }, ) @full_ns.expect(track_history_parser) - @auth_middleware(track_history_parser) + @auth_middleware(track_history_parser, require_auth=True) @full_ns.marshal_with(history_response_full) def get(self, id, authed_user_id=None): return self._get(id, authed_user_id) @@ -1111,10 +1115,16 @@ class TrackHistory(TrackHistoryFull): id="""Get User's Track History""", description="""Get the tracks the user recently listened to.""", params={"id": "A User ID"}, - responses={200: "Success", 400: "Bad request", 500: "Server error"}, + responses={ + 200: "Success", + 400: "Bad request", + 401: "Unauthorized", + 403: "Forbidden", + 500: "Server error", + }, ) @ns.expect(track_history_parser) - @auth_middleware(track_history_parser) + @auth_middleware(track_history_parser, require_auth=True) @ns.marshal_with(history_response) def get(self, id, authed_user_id): return super()._get(id, authed_user_id) diff --git a/packages/discovery-provider/src/queries/get_user_listening_history.py b/packages/discovery-provider/src/queries/get_user_listening_history.py index 6a2944c9794..43feb575c5a 100644 --- a/packages/discovery-provider/src/queries/get_user_listening_history.py +++ b/packages/discovery-provider/src/queries/get_user_listening_history.py @@ -25,9 +25,6 @@ class GetUserListeningHistoryArgs(TypedDict): # The current user logged in (from route param) user_id: int - # The current user logged in (from query arg) - current_user_id: int - # The maximum number of listens to return limit: int @@ -44,7 +41,7 @@ class GetUserListeningHistoryArgs(TypedDict): def get_user_listening_history(args: GetUserListeningHistoryArgs): """ - Returns a user's listening history + Returns a user's listening history. DOES NOT check authorization. Args: args: GetUserListeningHistoryArgs The parsed args from the request @@ -60,7 +57,6 @@ def get_user_listening_history(args: GetUserListeningHistoryArgs): def _get_user_listening_history(session: Session, args: GetUserListeningHistoryArgs): user_id = args["user_id"] - current_user_id = args["current_user_id"] limit = args["limit"] offset = args["offset"] query = args["query"] @@ -68,12 +64,9 @@ def _get_user_listening_history(session: Session, args: GetUserListeningHistoryA sort_direction = args["sort_direction"] sort_fn = desc if sort_direction == SortDirection.desc else asc - if user_id != current_user_id: - return [] - listening_history_results = ( session.query(UserListeningHistory.listening_history).filter( - UserListeningHistory.user_id == current_user_id + UserListeningHistory.user_id == user_id ) ).scalar() @@ -119,9 +112,9 @@ def _get_user_listening_history(session: Session, args: GetUserListeningHistoryA # bundle peripheral info into track results tracks = populate_track_metadata( - session, track_ids, tracks, current_user_id, track_has_aggregates=True + session, track_ids, tracks, current_user_id=user_id, track_has_aggregates=True ) - tracks = add_users_to_tracks(session, tracks, current_user_id) + tracks = add_users_to_tracks(session, tracks, current_user_id=user_id) for track in tracks: track[response_name_constants.activity_timestamp] = listen_dates[ diff --git a/packages/web/src/common/store/pages/history/lineups/sagas.ts b/packages/web/src/common/store/pages/history/lineups/sagas.ts index edb3ab5c03f..f99dc499794 100644 --- a/packages/web/src/common/store/pages/history/lineups/sagas.ts +++ b/packages/web/src/common/store/pages/history/lineups/sagas.ts @@ -1,15 +1,16 @@ -import { LineupEntry, Track, UserTrackMetadata } from '@audius/common/models' +import { + Id, + LineupEntry, + Track, + UserTrackMetadata +} from '@audius/common/models' import { makeActivity } from '@audius/common/services' import { accountSelectors, getContext, historyPageTracksLineupActions as tracksActions } from '@audius/common/store' -import { - decodeHashId, - encodeHashId, - removeNullable -} from '@audius/common/utils' +import { decodeHashId, removeNullable } from '@audius/common/utils' import { keyBy } from 'lodash' import { call, select } from 'typed-redux-saga' @@ -27,11 +28,13 @@ function* getHistoryTracks() { try { const currentUserId = yield* select(getUserId) if (!currentUserId) return [] + const hashedId = Id.parse(currentUserId) const activity = yield* call( [sdk.full.users, sdk.full.users.getUsersTrackHistory], { - id: encodeHashId(currentUserId), + id: hashedId, + userId: hashedId, limit: 100 } ) From e75a30f4c75f68cb750b117c04e91a79fddc13ed Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Mon, 3 Jun 2024 12:52:12 -0400 Subject: [PATCH 3/6] gen sdk --- .../libs/src/sdk/api/generated/default/apis/PlaylistsApi.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/libs/src/sdk/api/generated/default/apis/PlaylistsApi.ts b/packages/libs/src/sdk/api/generated/default/apis/PlaylistsApi.ts index cbd900fc297..7168f730601 100644 --- a/packages/libs/src/sdk/api/generated/default/apis/PlaylistsApi.ts +++ b/packages/libs/src/sdk/api/generated/default/apis/PlaylistsApi.ts @@ -145,7 +145,7 @@ export class PlaylistsApi extends runtime.BaseAPI { /** * @hidden - * Gets information necessary to access the playlist and what access the given user has. + * Gets the information necessary to access the playlist and what access the given user has. */ async getPlaylistAccessInfoRaw(params: GetPlaylistAccessInfoRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise> { if (params.playlistId === null || params.playlistId === undefined) { @@ -171,7 +171,7 @@ export class PlaylistsApi extends runtime.BaseAPI { } /** - * Gets information necessary to access the playlist and what access the given user has. + * Gets the information necessary to access the playlist and what access the given user has. */ async getPlaylistAccessInfo(params: GetPlaylistAccessInfoRequest, initOverrides?: RequestInit | runtime.InitOverrideFunction): Promise { const response = await this.getPlaylistAccessInfoRaw(params, initOverrides); From 719d78e27badbb2dcc914d032566959efe1069c4 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:02:00 -0400 Subject: [PATCH 4/6] remove user_id from history endpoint --- packages/discovery-provider/src/api/v1/helpers.py | 2 +- packages/libs/src/sdk/api/generated/full/apis/UsersApi.ts | 5 ----- packages/web/src/common/store/pages/history/lineups/sagas.ts | 1 - 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/packages/discovery-provider/src/api/v1/helpers.py b/packages/discovery-provider/src/api/v1/helpers.py index f2fa40e3399..4bc89f95ef2 100644 --- a/packages/discovery-provider/src/api/v1/helpers.py +++ b/packages/discovery-provider/src/api/v1/helpers.py @@ -705,7 +705,7 @@ def __schema__(self): search_parser = reqparse.RequestParser(argument_class=DescriptiveArgument) search_parser.add_argument("query", required=True, description="The search query") -track_history_parser = pagination_with_current_user_parser.copy() +track_history_parser = pagination_parser.copy() track_history_parser.add_argument( "query", required=False, description="The filter query" ) diff --git a/packages/libs/src/sdk/api/generated/full/apis/UsersApi.ts b/packages/libs/src/sdk/api/generated/full/apis/UsersApi.ts index d6dd86e19ee..086f3de1ab5 100644 --- a/packages/libs/src/sdk/api/generated/full/apis/UsersApi.ts +++ b/packages/libs/src/sdk/api/generated/full/apis/UsersApi.ts @@ -369,7 +369,6 @@ export interface GetUsersTrackHistoryRequest { id: string; offset?: number; limit?: number; - userId?: string; query?: string; sortMethod?: GetUsersTrackHistorySortMethodEnum; sortDirection?: GetUsersTrackHistorySortDirectionEnum; @@ -2029,10 +2028,6 @@ export class UsersApi extends runtime.BaseAPI { queryParameters['limit'] = params.limit; } - if (params.userId !== undefined) { - queryParameters['user_id'] = params.userId; - } - if (params.query !== undefined) { queryParameters['query'] = params.query; } diff --git a/packages/web/src/common/store/pages/history/lineups/sagas.ts b/packages/web/src/common/store/pages/history/lineups/sagas.ts index f99dc499794..261f10c3e65 100644 --- a/packages/web/src/common/store/pages/history/lineups/sagas.ts +++ b/packages/web/src/common/store/pages/history/lineups/sagas.ts @@ -34,7 +34,6 @@ function* getHistoryTracks() { [sdk.full.users, sdk.full.users.getUsersTrackHistory], { id: hashedId, - userId: hashedId, limit: 100 } ) From c2d6424ab97ad0034ef03cd087e70e1dba10fe46 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:10:54 -0400 Subject: [PATCH 5/6] add changeset --- .changeset/witty-donuts-sit.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/witty-donuts-sit.md diff --git a/.changeset/witty-donuts-sit.md b/.changeset/witty-donuts-sit.md new file mode 100644 index 00000000000..8b144beae0b --- /dev/null +++ b/.changeset/witty-donuts-sit.md @@ -0,0 +1,5 @@ +--- +'@audius/sdk': patch +--- + +Remove unused user_id parameter from track history endpoint From fa25afe1148657605088081a90644c79e45bc119 Mon Sep 17 00:00:00 2001 From: Randy Schott <1815175+schottra@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:15:57 -0400 Subject: [PATCH 6/6] remove unnecessary doc string --- .../src/queries/get_user_listening_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/discovery-provider/src/queries/get_user_listening_history.py b/packages/discovery-provider/src/queries/get_user_listening_history.py index 43feb575c5a..9de13230b73 100644 --- a/packages/discovery-provider/src/queries/get_user_listening_history.py +++ b/packages/discovery-provider/src/queries/get_user_listening_history.py @@ -41,7 +41,7 @@ class GetUserListeningHistoryArgs(TypedDict): def get_user_listening_history(args: GetUserListeningHistoryArgs): """ - Returns a user's listening history. DOES NOT check authorization. + Returns a user's listening history. Args: args: GetUserListeningHistoryArgs The parsed args from the request