diff --git a/src/logid/backend/hidpp10/Receiver.cpp b/src/logid/backend/hidpp10/Receiver.cpp index 9cefc760..1ae5b48f 100644 --- a/src/logid/backend/hidpp10/Receiver.cpp +++ b/src/logid/backend/hidpp10/Receiver.cpp @@ -18,6 +18,7 @@ #include #include +#include using namespace logid::backend::hidpp10; using namespace logid::backend; @@ -45,6 +46,8 @@ void Receiver::_receiverCheck() { Receiver::NotificationFlags Receiver::getNotificationFlags() { auto response = getRegister(EnableHidppNotifications, {}, hidpp::ReportType::Short); + if (response.size() < 2) + throw std::runtime_error("getNotificationFlags: short response"); NotificationFlags flags{}; flags.deviceBatteryStatus = response[0] & (1 << 4); @@ -74,6 +77,8 @@ void Receiver::enumerate() { ///TODO: Investigate usage uint8_t Receiver::getConnectionState(hidpp::DeviceIndex index) { auto response = getRegister(ConnectionState, {index}, hidpp::ReportType::Short); + if (response.empty()) + throw std::runtime_error("getConnectionState: empty response"); return response[0]; } @@ -159,8 +164,11 @@ std::map Receiver::getDeviceActivity() { auto response = getRegister(DeviceActivity, {}, hidpp::ReportType::Long); std::map device_activity; - for (uint8_t i = hidpp::WirelessDevice1; i <= hidpp::WirelessDevice6; i++) + for (uint8_t i = hidpp::WirelessDevice1; i <= hidpp::WirelessDevice6; i++) { + if (i >= response.size()) + break; device_activity[static_cast(i)] = response[i]; + } return device_activity; } @@ -178,6 +186,8 @@ Receiver::getPairingInfo(hidpp::DeviceIndex index) { struct PairingInfo info{}; if (_is_bolt) { + if (response.size() < 4) + throw std::runtime_error("getPairingInfo: short bolt response"); info = { .destinationId = 0, // no longer given? .reportInterval = 0, // no longer given? @@ -185,6 +195,8 @@ Receiver::getPairingInfo(hidpp::DeviceIndex index) { .deviceType = static_cast(response[1]) }; } else { + if (response.size() < 8) + throw std::runtime_error("getPairingInfo: short response"); info = { .destinationId = response[1], .reportInterval = response[2], @@ -208,6 +220,10 @@ Receiver::getExtendedPairingInfo(hidpp::DeviceIndex index) { auto response = getRegister(PairingInfo, request, hidpp::ReportType::Long); + const std::size_t required = static_cast(psl_offset) + 1; + if (response.size() < required) + throw std::runtime_error("getExtendedPairingInfo: short response"); + ExtendedPairingInfo info{}; info.serialNumber = 0; @@ -274,12 +290,18 @@ std::string Receiver::getDeviceName(hidpp::DeviceIndex index) { } hidpp::DeviceIndex Receiver::deviceDisconnectionEvent(const hidpp::Report& report) { - assert(report.subId() == DeviceDisconnection); + if (report.subId() != DeviceDisconnection) + throw std::runtime_error("deviceDisconnectionEvent: bad subId"); return report.deviceIndex(); } hidpp::DeviceConnectionEvent Receiver::deviceConnectionEvent(const hidpp::Report& report) { - assert(report.subId() == DeviceConnection); + if (report.subId() != DeviceConnection) + throw std::runtime_error("deviceConnectionEvent: bad subId"); + + const auto param_size = static_cast(report.paramEnd() - report.paramBegin()); + if (param_size < 3) + throw std::runtime_error("deviceConnectionEvent: short report"); auto data = report.paramBegin(); @@ -299,7 +321,14 @@ hidpp::DeviceConnectionEvent Receiver::deviceConnectionEvent(const hidpp::Report bool Receiver::fillDeviceDiscoveryEvent(DeviceDiscoveryEvent& event, const hidpp::Report& report) { - assert(report.subId() == DeviceDiscovered); + if (report.subId() != DeviceDiscovered) + throw std::runtime_error("fillDeviceDiscoveryEvent: bad subId"); + if (report.type() != hidpp::Report::Type::Long) + throw std::runtime_error("fillDeviceDiscoveryEvent: short report"); + + const auto param_size = static_cast(report.paramEnd() - report.paramBegin()); + if (param_size < hidpp::LongParamLength) + throw std::runtime_error("fillDeviceDiscoveryEvent: truncated payload"); auto data = report.paramBegin(); @@ -343,7 +372,11 @@ bool Receiver::fillDeviceDiscoveryEvent(DeviceDiscoveryEvent& event, } PairStatusEvent Receiver::pairStatusEvent(const hidpp::Report& report) { - assert(report.subId() == PairStatus); + if (report.subId() != PairStatus) + throw std::runtime_error("pairStatusEvent: bad subId"); + const auto param_size = static_cast(report.paramEnd() - report.paramBegin()); + if (param_size < 2) + throw std::runtime_error("pairStatusEvent: short report"); return { .pairing = (bool)(report.paramBegin()[0] & 1), @@ -352,7 +385,11 @@ PairStatusEvent Receiver::pairStatusEvent(const hidpp::Report& report) { } BoltPairStatusEvent Receiver::boltPairStatusEvent(const hidpp::Report& report) { - assert(report.subId() == BoltPairStatus); + if (report.subId() != BoltPairStatus) + throw std::runtime_error("boltPairStatusEvent: bad subId"); + const auto param_size = static_cast(report.paramEnd() - report.paramBegin()); + if (param_size < 2) + throw std::runtime_error("boltPairStatusEvent: short report"); return { .pairing = report.address() == 0, @@ -361,7 +398,11 @@ BoltPairStatusEvent Receiver::boltPairStatusEvent(const hidpp::Report& report) { } DiscoveryStatusEvent Receiver::discoveryStatusEvent(const hidpp::Report& report) { - assert(report.subId() == DiscoveryStatus); + if (report.subId() != DiscoveryStatus) + throw std::runtime_error("discoveryStatusEvent: bad subId"); + const auto param_size = static_cast(report.paramEnd() - report.paramBegin()); + if (param_size < 2) + throw std::runtime_error("discoveryStatusEvent: short report"); return { .discovering = report.address() == 0, @@ -370,7 +411,13 @@ DiscoveryStatusEvent Receiver::discoveryStatusEvent(const hidpp::Report& report) } std::string Receiver::passkeyEvent(const hidpp::Report& report) { - assert(report.subId() == PasskeyRequest); + if (report.subId() != PasskeyRequest) + throw std::runtime_error("passkeyEvent: bad subId"); + if (report.type() != hidpp::Report::Type::Long) + throw std::runtime_error("passkeyEvent: short report"); + const auto param_size = static_cast(report.paramEnd() - report.paramBegin()); + if (param_size < 6) + throw std::runtime_error("passkeyEvent: truncated payload"); return {report.paramBegin(), report.paramBegin() + 6}; }