TicketNo:#IB4OMX Description:新需求: 更新发现连接方法及AES加解密

Signed-off-by: LongestDistance <cdwango@isoftstone.com>
This commit is contained in:
LongestDistance 2024-11-15 16:31:25 +08:00
parent c831294373
commit 4068ae6a41
13 changed files with 540 additions and 393 deletions

View File

@ -0,0 +1,136 @@
/*
* Copyright (C) 2023-2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* Description: implement the cast source connect
* Author: chenggong
* Create: 2023-10-23
*/
#ifndef CAST_META_NODE_CONSTANT_H
#define CAST_META_NODE_CONSTANT_H
#include <string>
namespace OHOS {
namespace CastEngine {
namespace CastEngineService {
const std::string KEY_BIND_TARGET_ACTION = "action";
constexpr int ACTION_CONNECT_DEVICE = 0;
constexpr int ACTION_QUERY_P2P_IP = 1;
constexpr int ACTION_SEND_MESSAGE = 2;
constexpr int ACTION_FINISH_BINDING_STATUS = 3;
const std::string KEY_LOCAL_NETWORK_ID = "localNetworkId";
const std::string KEY_REMOTE_NETWORK_ID = "remoteNetworkId";
const std::string KEY_LOCAL_WIFI_IP = "localWifiIp";
const std::string KEY_REMOTE_WIFI_IP = "remoteWifiIp";
const std::string KEY_LOCAL_P2P_IP = "localP2PIp";
const std::string KEY_REMOTE_P2P_IP = "remoteP2PIp";
const std::string NETWORK_ID = "networkId";
const std::string KEY_TRANSFER_MODE = "transferMode";
const std::string KEY_SESSION_ID = "sessionId";
const std::string PORT_KEY = "port";
const std::string SOURCE_IP_KEY = "sourceIp";
const std::string SINK_IP_KEY = "sinkIp";
const std::string TYPE_SESSION_KEY = "sessionKey";
const std::string KEY_CONNECT_TYPE = "connectType";
const std::string CONNECT_TYPE_WIFI = "lan"; // Do not modify the string
const std::string CONNECT_TYPE_P2P = "p2p"; // Do not modify the string
const std::string KEY_PROJECTION_MODE = "projectionMode";
const std::string PROJECTION_MODE_MIRROR = "MIRROR";
const std::string PROJECTION_MODE_STREAM = "MEDIA_RESOURCE";
const std::string KEY_UNBIND_TARGET_CAUSE = "unbindCause";
constexpr const char *AIRSHARING_MODULE_NAME = "com.huawei.android.airsharing+CastPlusDiscoveryModule";
constexpr int CAUSE_AUTH_SUCCESS = 0;
constexpr int CAUSE_AUTH_FAILED = 1;
constexpr int CAUSE_DISCONNECT = 2;
inline constexpr int OK = 0;
inline constexpr int INVALID_VALUE = -1;
constexpr int AUTH_MODE_GENERIC = 1;
constexpr int AUTH_MODE_PWD = 2;
constexpr int CAST_SESSION_KEY_LENGTH = 16;
const std::string TIME_RECORD = "TimeRecord";
const std::string TOTAL_AUTH_TIME = "totalAuthTime";
const std::string TIME_RECORD_STRING = "timeRecordString";
const std::string VERSION_OH = "OH1.0";
const std::string PACKAGE_NAME_KEY = "packageName";
const std::string VERSION_NAME_KEY = "castpluskitVersionName";
const std::string PRODUCT_NAME_KEY = "productName";
const std::string CAST_PLUS_APP_VERSION_KEY = "castplusAppVersion";
const std::string PRODUCT_ID_KEY = "productId";
const std::string IS_VIRGIN_SINK = "isVirginSink";
const std::string ALLOWED_ALWAYS = "allowedAlways";
const std::string DEVICE_CAST_SOURCE = "deviceCastSource";
const std::string AUTH_MODE_KEY = "authMode";
const std::string VERSION_KEY = "version";
const std::string OPERATION_TYPE_KEY = "operType";
const std::string SEQUENCE_NUMBER = "sequenceNumber";
const std::string DATA_KEY = "data";
const std::string DEVICE_SALT_KEY = "salt";
const std::string CONNECT_TYPE_PRIORITY_KEY = "ConnectTypePriority";
const std::string TRIGGER_TYPE_KEY = "triggerType";
const std::string HANDSHAKE_RESULT_KEY = "handshakeResult";
const std::string DEVICE_IP_KEY = "deviceIp";
const std::string CURRENT_TIME_KEY = "currentTime";
const std::string KIT_AUTH_STATUS_KEY = "kitAuthStatus";
const std::string KIT_AUTH_POLICY_KEY = "kitAuthPolicy";
const std::string IS_GENERIC_TRUSTED_KEY = "isGenericTrusted";
const std::string IS_PWD_TRUSTED_KEY = "isPwdTrusted";
const std::string IS_SAME_ACCOUNT = "isSameAccount";
const std::string IS_SAME_FAMILY_GROUP = "isSameFamilyGroup";
const std::string IS_CONFIRMED_KEY = "isConfirmed";
const std::string DEVICE_ONLINE_INFO_KEY = "onlineInfo";
const std::string COORDINATION_CAPABILITY_KEY = "hasCoordinationCapability";
const std::string KIT_AUTH_CHOICE_KEY = "kitAuthChoice";
const std::string PHONE_MODEL_VERSION_KEY = "phoneModelVersion";
const std::string MODEL_NAME_KEY = "modelName";
const std::string CONSULT_RESULT = "consultResult";
const std::string SUB_REASON = "subReason";
const std::string SN_KEY = "serial_number";
const std::string CAPABILITY_INFO_KEY = "capabilityInfo";
// consult key
const std::string DEVICE_ID_KEY = "deviceId";
const std::string DEVICE_NAME_KEY = "deviceName";
const std::string REMOTE_DEVICE_NAME_KEY = "remoteDeviceName";
const std::string PROTOCOL_TYPE_KEY = "protocolType";
const std::string ACCOUNT_ID_KEY = "accountId";
const std::string USER_ID_KEY = "userId";
const std::string KIT_VER_PREFIX = "castpluskit";
// Prefix of third party castplus kit
const std::string KIT_VER_1X = "castpluskit 1";
// Prefix of HMOS2.x castplus kit
const std::string KIT_VER_2X = "castpluskit 2";
enum OperationStep : int32_t {
OPERATION_HANDSHAKE = 1,
OPERATION_AUTHENTICATE = 2,
OPERATION_CONSULT = 3,
};
enum DmAuthStatusExt : int32_t {
HANDSHAKE_RESPONSE = 1000,
};
} // namespace CastEngineService
} // namespace CastEngine
} // namespace OHOS
#endif

View File

@ -96,7 +96,6 @@ private:
ServiceStatus serviceStatus_{ ServiceStatus::DISCONNECTED };
int sessionCapacity_{ 0 };
std::map<int32_t, sptr<ICastSessionImpl>> sessionMap_;
std::atomic<int> sessionIndex_{ 0 };
std::unordered_map<pid_t, sptr<IRemoteObject::DeathRecipient>> deathRecipientMap_;
std::atomic<bool> hasServer_{ false };
std::atomic<bool> isUnloading_{ false };

View File

@ -560,13 +560,9 @@ int32_t CastSessionManagerService::SetSinkSessionCapacity(int sessionCapacity)
int32_t CastSessionManagerService::StartDiscovery(int protocols, std::vector<std::string> drmSchemes)
{
static_cast<void>(protocols);
CLOGI("StartDiscovery in, protocolType = %{public}d, drm shcheme size = %{public}zu", protocols, drmSchemes.size());
SharedRLock lock(mutex_);
if (!Permission::CheckPidPermission()) {
return ERR_NO_PERMISSION;
}
DiscoveryManager::GetInstance().StartDiscovery();
DiscoveryManager::GetInstance().StartDiscovery(protocols, drmSchemes);
return CAST_ENGINE_SUCCESS;
}

View File

@ -95,7 +95,7 @@ public:
void Init(std::shared_ptr<IDiscoveryManagerListener> listener);
void Deinit();
void StartDiscovery();
void StartDiscovery(int protocols, std::vector<std::string> drmSchemes);
void StopDiscovery();
bool StartAdvertise();
@ -137,13 +137,15 @@ private:
std::mutex mutex_;
int32_t uid_{ 0 };
int protocolType_ = 0;
bool isNotifyDevice_{ false };
int protocolType_;
std::vector<std::string> drmSchemes_;
std::shared_ptr<IDiscoveryManagerListener> listener_;
std::shared_ptr<EventRunner> eventRunner_;
std::shared_ptr<DiscoveryEventHandler> eventHandler_;
std::unordered_map<CastInnerRemoteDevice, int> remoteDeviceMap_;
int32_t scanCount_;
bool hasStartDiscovery_ = false;
};
} // namespace CastEngineService
} // namespace CastEngine

View File

@ -321,16 +321,16 @@ RemoteDeviceState CastDeviceDataManager::GetDeviceState(const std::string &devic
return GetDeviceStateLocked(deviceId);
}
bool CastDeviceDataManager::IsDeviceConnecting(const std::string &deviceId)
{
return GetDeviceState(deviceId) == RemoteDeviceState::CONNECTING;
}
bool CastDeviceDataManager::IsDeviceConnected(const std::string &deviceId)
{
return GetDeviceState(deviceId) == RemoteDeviceState::CONNECTED;
}
bool CastDeviceDataManager::IsDeviceConnecting(const std::string &deviceId)
{
return GetDeviceState(deviceId) == RemoteDeviceState::CONNECTING;
}
bool CastDeviceDataManager::IsDeviceUsed(const std::string &deviceId)
{
auto state = GetDeviceState(deviceId);

View File

@ -106,8 +106,7 @@ SubDeviceType ConvertSubDeviceType(uint16_t deviceTypeId)
const std::string AUTH_VERSION_KEY = "authVersion";
const std::string AUTH_VERSION_1 = "1.0";
const std::string AUTH_VERSION_2 = "2.0";
constexpr int THIRD_TV = 0x2E;
const std::string AUTH_VERSION_3 = "DM";
const std::string KEY_BIND_TARGET_ACTION = "action";
constexpr int ACTION_CONNECT_DEVICE = 0;
@ -118,12 +117,6 @@ const std::string KEY_LOCAL_P2P_IP = "localP2PIp";
const std::string KEY_REMOTE_P2P_IP = "remoteP2PIp";
const std::string NETWORK_ID = "networkId";
constexpr static int SECOND_BYTE_OFFSET = 8;
constexpr static int THIRD_BYTE_OFFSET = 16;
constexpr static int FOURTH_BYTE_OFFSET = 24;
constexpr static int INT_FOUR = 4;
void DeviceDiscoveryWriteWrap(const std::string& funcName, const std::string& puid)
{
HiSysEventWriteWrap(funcName, {
@ -151,20 +144,6 @@ void EstablishConsultWriteWrap(const std::string& funcName, int sceneType, const
{"PEER_UDID", puid}});
}
void DeviceAuthWriteWrap(const std::string& funcName, int sceneType, const std::string& puid)
{
HiSysEventWriteWrap(funcName, {
{"BIZ_SCENE", sceneType},
{"BIZ_STATE", static_cast<int32_t>(BIZStateType::BIZ_STATE_BEGIN)},
{"BIZ_STAGE", static_cast<int32_t>(BIZSceneStage::DEVICE_AUTHENTICATION)},
{"STAGE_RES", static_cast<int32_t>(StageResType::STAGE_RES_SUCCESS)},
{"ERROR_CODE", CAST_RADAR_SUCCESS}}, {
{"TO_CALL_PKG", DEVICE_MANAGER_NAME},
{"LOCAL_SESS_NAME", ""},
{"PEER_SESS_NAME", ""},
{"PEER_UDID", puid}});
}
} // namespace
namespace SoftBus {
@ -289,6 +268,11 @@ int BindSocket(int32_t socketId, const ProtocolType &protocolType, bool isSingle
return SOFTBUS_OK;
}
} // namespace SoftBus
constexpr static int SECOND_BYTE_OFFSET = 8;
constexpr static int THIRD_BYTE_OFFSET = 16;
constexpr static int FOURTH_BYTE_OFFSET = 24;
constexpr static int INT_FOUR = 4;
/*
* auth success
@ -312,7 +296,7 @@ const std::string USER_ID_KEY = "userId";
* User's unusual action or other event scenarios could cause changing of STATE or RESULT which delivered
* by DM.
*/
const std::map<int32_t, int32_t> CastBindTargetCallback::RESULT_REASON_MAP = {
const std::map<int, int32_t> CastBindTargetCallback::RESULT_REASON_MAP = {
// SINK peer click distrust button during 3-state authentication.
{ ERR_DM_AUTH_PEER_REJECT, REASON_DISTRUST_BY_SINK },
// SINK peer click cancel button during pin code inputting.
@ -325,7 +309,7 @@ const std::map<int32_t, int32_t> CastBindTargetCallback::RESULT_REASON_MAP = {
{ STOP_BIND, REASON_STOP_BIND_BY_SOURCE }
};
const std::map<int32_t, int32_t> CastBindTargetCallback::STATUS_REASON_MAP = {
const std::map<int, int32_t> CastBindTargetCallback::STATUS_REASON_MAP = {
// DEFAULT event
{ DmAuthStatus::STATUS_DM_AUTH_DEFAULT, REASON_DEFAULT },
// Sink peer click trust during 3-state authentication.
@ -494,29 +478,30 @@ void ConnectionManager::OnConsultDataReceived(int transportId, const void *data,
CLOGE("Failed to get DmDeviceInfo");
return;
}
int castSessionId = INVALID_ID;
constexpr int32_t sleepTimeMs = 50;
constexpr int32_t retryTimes = 20;
int32_t retryTime = 0;
while (castSessionId == INVALID_ID) {
if (castSessionId != INVALID_ID || retryTime > retryTimes) {
break;
}
int castSessionId = GetCastSessionId(transportId);
for (int32_t retryTime = 1; castSessionId == INVALID_ID && retryTime < retryTimes; retryTime++) {
CLOGD("Retry for the %d(th) time after sleeping %dms", retryTime, sleepTimeMs);
std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMs));
castSessionId = GetCastSessionId(transportId);
retryTime++;
}
if (castSessionId == INVALID_ID) {
CLOGE("session id invalid");
CLOGE("Invalid CastSessionId!");
return;
}
CLOGI("protocolType is %d", device->protocolType);
if (device->protocolType == ProtocolType::CAST_PLUS_STREAM) {
SetSessionProtocolType(castSessionId, device->protocolType);
}
if (!listener_) {
CLOGE("Detect absence of listener_.");
return;
}
listener_->ReportSessionCreate(castSessionId);
device->localCastSessionId = castSessionId;
if (!CastDeviceDataManager::GetInstance().AddDevice(*device, dmDevice)) {
return;
}
@ -535,13 +520,12 @@ void ConnectionManager::OnConsultDataReceived(int transportId, const void *data,
int ConnectionManager::GetCastSessionId(int transportId)
{
std::lock_guard<std::mutex> lock(mutex_);
for (const auto &element : transIdToCastSessionIdMap_) {
if (element.first == transportId) {
return element.second;
}
if (transIdToCastSessionIdMap_.count(transportId) == 1) {
return transIdToCastSessionIdMap_[transportId];
} else {
CLOGE("Invalid transport id:%{public}d", transportId);
return INVALID_ID;
}
CLOGE("Invalid transport id:%{public}d", transportId);
return INVALID_ID;
}
bool ConnectionManager::OnConsultSessionOpened(int transportId, bool isSource)
@ -655,12 +639,10 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
DeviceDiscoveryWriteWrap(__func__, GetAnonymousDeviceID(dev.deviceId));
auto &deviceId = dev.deviceId;
CLOGI("ConnectDevice in, %s", deviceId.c_str());
if (CastDeviceDataManager::GetInstance().IsDeviceUsed(deviceId)) {
CLOGD("Device: %s is used.", deviceId.c_str());
return true;
}
CLOGI("deviceId %{public}s, protocolType %{public}d, capabilityInfo %{public}d, wifiIp %{public}s, "
"bleMac %{public}s, isLeagacy %{public}d, isFresh wifi %{public}d, ble %{public}d",
Utils::Mask(deviceId).c_str(), protocolType, dev.capabilityInfo, Utils::Mask(dev.wifiIp).c_str(),
Utils::Mask(dev.bleMac).c_str(), dev.isLeagacy, dev.isWifiFresh, dev.isBleFresh);
if (!UpdateDeviceState(deviceId, RemoteDeviceState::CONNECTING)) {
CLOGE("Device(%s) is missing", deviceId.c_str());
@ -672,7 +654,7 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
if (IsNeedDiscoveryDevice(dev)) {
CLOGI("need discovery device");
DiscoveryManager::GetInstance().StartDiscovery();
DiscoveryManager::GetInstance().StartDiscovery(static_cast<int>(protocolType), {});
std::thread([this, dev]() {
Utils::SetThreadName("ConnectTargetDevice");
WaitAndConnectTargetDevice(dev);
@ -684,7 +666,7 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
std::string networkId;
if (IsDeviceTrusted(dev.deviceId, networkId) && IsSingle(dev) && SourceCheckConnectAccess(networkId)) {
DeviceAuthWriteWrap(__func__, GetBIZSceneType(GetProtocolType()), GetAnonymousDeviceID(dev.deviceId));
NotifyListenerToLoadSinkSA(networkId);
if (!CastDeviceDataManager::GetInstance().SetDeviceNetworkId(deviceId, networkId) ||
!OpenConsultSession(dev)) {
(void)UpdateDeviceState(deviceId, RemoteDeviceState::FOUND);
@ -697,12 +679,13 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
(void)UpdateDeviceState(deviceId, RemoteDeviceState::FOUND);
return false;
}
std::unique_lock<std::mutex> lock(mutex_);
if (isBindTargetMap_.find(deviceId) != isBindTargetMap_.end()) {
isBindTargetMap_[deviceId] = true;
} else {
isBindTargetMap_.insert({ deviceId, true });
}
CLOGI("ConnectDevice out, %s", deviceId.c_str());
CLOGI("ConnectDevice out, %{public}s", Utils::Mask(deviceId).c_str());
return true;
}
@ -710,6 +693,8 @@ void ConnectionManager::DisconnectDevice(const std::string &deviceId)
{
CLOGI("DisconnectDevice in, deviceId %{public}s", Utils::Mask(deviceId).c_str());
std::unique_lock<std::mutex> lock(mutex_);
connectingDeviceId_ = "";
DiscoveryManager::GetInstance().StopDiscovery();
if (!CastDeviceDataManager::GetInstance().IsDeviceUsed(deviceId)) {
CLOGE("Device(%s) is not used, remove it", deviceId.c_str());
@ -717,7 +702,11 @@ void ConnectionManager::DisconnectDevice(const std::string &deviceId)
return;
}
protocolType_ = ProtocolType::CAST_PLUS_MIRROR;
lock.unlock();
UpdateDeviceState(deviceId, RemoteDeviceState::FOUND);
DestroyConsulationSession(deviceId);
CastDeviceDataManager::GetInstance().GetDeviceByDeviceId(deviceId);
auto isActiveAuth = CastDeviceDataManager::GetInstance().GetDeviceIsActiveAuth(deviceId);
if (isActiveAuth == std::nullopt) {
return;
@ -778,12 +767,17 @@ int32_t ConnectionManager::GetLocalDeviceInfo(CastLocalDevice &device)
void ConnectionManager::NotifySessionIsReady(int transportId)
{
if (!listener_) {
CLOGE("Detect absence of listener_.");
return;
}
int castSessionId = listener_->NotifySessionIsReady();
if (castSessionId == INVALID_ID) {
CLOGE("sessionId is invalid");
return;
}
CLOGD("Update cast session id map: %d: %d", transportId, castSessionId);
std::lock_guard<std::mutex> lock(mutex_);
transIdToCastSessionIdMap_.insert({ transportId, castSessionId });
}
@ -791,11 +785,11 @@ void ConnectionManager::NotifySessionIsReady(int transportId)
void ConnectionManager::NotifyDeviceIsOffline(const std::string &deviceId)
{
CLOGI("NotifyDeviceIsOffline in");
std::lock_guard<std::mutex> lock(mutex_);
if (!listener_) {
auto listener = GetListener();
if (!listener) {
return;
}
listener_->NotifyDeviceIsOffline(deviceId);
listener->NotifyDeviceIsOffline(deviceId);
}
bool ConnectionManager::NotifyConnectStage(const CastInnerRemoteDevice &device, int result, int32_t reasonCode)
@ -972,26 +966,12 @@ bool ConnectionManager::BindTarget(const CastInnerRemoteDevice &dev)
BuildBindParam(dev, bindParam);
int ret = DeviceManager::GetInstance().BindTarget(PKG_NAME, targetId, bindParam,
std::make_shared<CastBindTargetCallback>());
if (ret == ERR_DM_AUTH_BUSINESS_BUSY) {
CLOGE("bind fail, target is binding %d", ret);
auto networkId = CastDeviceDataManager::GetInstance().GetDeviceNetworkId(dev.deviceId);
PeerTargetId targetId = {
.deviceId = *networkId,
};
std::map<std::string, std::string> unbindParam{};
if (!CastDeviceDataManager::GetInstance().IsDoubleFrameDevice(dev.deviceId)) {
DeviceManager::GetInstance().UnbindTarget(PKG_NAME, targetId, unbindParam, nullptr);
} else {
unbindParam.insert(
std::pair<std::string, std::string>(PARAM_KEY_META_TYPE, std::to_string(5)));
DeviceManager::GetInstance().UnbindTarget(PKG_NAME, targetId, unbindParam, nullptr);
}
return false;
}
if (ret != DM_OK) {
CLOGE("ConnectDevice BindTarget fail, ret = %{public}d)", ret);
CastEngineDfx::WriteErrorEvent(AUTHENTICATE_DEVICE_FAIL);
if (ret == ERR_DM_AUTH_BUSINESS_BUSY) {
DeviceManager::GetInstance().UnbindTarget(
PKG_NAME, targetId, bindParam, std::make_shared<CastUnBindTargetCallback>());
}
return false;
}
@ -1132,12 +1112,15 @@ void ConnectionManager::EncryptPort(int port, const uint8_t *sessionKey, json &b
int portArraySize = 4;
ConstPacketData inputData = { portArray.get(), portArraySize };
uint8_t encryptedPort[portArraySize + EncryptDecrypt::AES_IV_LEN];
PacketData outputData = { encryptedPort, 0 };
EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, sessionKey, SESSION_KEY_LENGTH, inputData,
outputData);
CLOGD("encrypt result is %d ", outputData.length);
std::string encryptedPortLatin1(reinterpret_cast<const char *>(outputData.data), outputData.length);
int encryptedDataLen = 0;
auto encryptedData = EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, { sessionKey,
SESSION_KEY_LENGTH }, inputData, encryptedDataLen);
if (!encryptedData) {
CLOGE("encrypt error");
return;
}
CLOGD("encrypt result is %d ", encryptedDataLen);
std::string encryptedPortLatin1(reinterpret_cast<const char *>(encryptedData.get()), encryptedDataLen);
std::string encryptedPortUtf8 = convLatin1ToUTF8(encryptedPortLatin1);
body[PORT_KEY] = encryptedPortUtf8;
}
@ -1148,11 +1131,15 @@ void ConnectionManager::EncryptIp(const std::string &ip, const std::string &key,
return;
}
ConstPacketData inputData = { reinterpret_cast<const uint8_t *>(ip.c_str()), ip.size() };
uint8_t encrypted[ip.size() + EncryptDecrypt::AES_IV_LEN];
PacketData outputData = { encrypted, 0 };
EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, sessionKey, SESSION_KEY_LENGTH, inputData,
outputData);
for (int i = 0; i < outputData.length; i++) {
int encryptedDataLen = 0;
auto encryptedData = EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, { sessionKey,
SESSION_KEY_LENGTH }, inputData, encryptedDataLen);
if (!encryptedData) {
CLOGE("encrypt error");
return;
}
uint8_t *encrypted = encryptedData.get();
for (int i = 0; i < encryptedDataLen; i++) {
body[key].push_back(encrypted[i]);
}
CLOGI("encrypt %s finish", key.c_str());
@ -1160,12 +1147,13 @@ void ConnectionManager::EncryptIp(const std::string &ip, const std::string &key,
std::unique_ptr<uint8_t[]> ConnectionManager::intToByteArray(int32_t num)
{
unsigned int number = static_cast<unsigned int>(num);
std::unique_ptr<uint8_t[]> result = std::make_unique<uint8_t[]>(INT_FOUR);
int i = 0;
result[i] = (num >> FOURTH_BYTE_OFFSET) & 0xFF;
result[++i] = (num >> THIRD_BYTE_OFFSET) & 0xFF;
result[++i] = (num >> SECOND_BYTE_OFFSET) & 0xFF;
result[++i] = num & 0xFF;
unsigned int i = 0;
result[i] = (number >> FOURTH_BYTE_OFFSET) & 0xFF;
result[++i] = (number >> THIRD_BYTE_OFFSET) & 0xFF;
result[++i] = (number >> SECOND_BYTE_OFFSET) & 0xFF;
result[++i] = number & 0xFF;
return result;
}
@ -1401,20 +1389,20 @@ void ConnectionManager::ResetListener()
int32_t ConnectionManager::GetSessionProtocolType(int sessionId, ProtocolType &protocolType)
{
std::lock_guard<std::mutex> lock(mutex_);
if (!listener_) {
auto listener = GetListener();
if (!listener) {
return CAST_ENGINE_ERROR;
}
return listener_->GetSessionProtocolType(sessionId, protocolType);
return listener->GetSessionProtocolType(sessionId, protocolType);
}
int32_t ConnectionManager::SetSessionProtocolType(int sessionId, ProtocolType protocolType)
{
std::lock_guard<std::mutex> lock(mutex_);
if (!listener_) {
auto listener = GetListener();
if (!listener) {
return CAST_ENGINE_ERROR;
}
return listener_->SetSessionProtocolType(sessionId, protocolType);
return listener->SetSessionProtocolType(sessionId, protocolType);
}
void ConnectionManager::SendConsultInfo(const std::string &deviceId, int port)
@ -1622,39 +1610,32 @@ void CastDeviceStateCallback::OnDeviceOnline(const DmDeviceInfo &deviceInfo)
void CastDeviceStateCallback::OnDeviceOffline(const DmDeviceInfo &deviceInfo)
{
CLOGI("device(%s) is offline", deviceInfo.deviceId);
ConnectionManager::GetInstance().NotifyDeviceIsOffline(deviceInfo.deviceId);
CLOGI("device(%{public}s) is offline", Utils::Mask(deviceInfo.deviceId).c_str());
}
void CastDeviceStateCallback::OnDeviceChanged(const DmDeviceInfo &deviceInfo)
{
CLOGI("device(%s) is changed", deviceInfo.deviceId);
CLOGI("device(%{public}s) is changed", Utils::Mask(deviceInfo.deviceId).c_str());
}
void CastDeviceStateCallback::OnDeviceReady(const DmDeviceInfo &deviceInfo)
{
CLOGI("device(%s) is ready", deviceInfo.deviceId);
CLOGI("device(%{public}s) is ready", Utils::Mask(deviceInfo.deviceId).c_str());
}
bool ConnectionManager::IsSingle(const CastInnerRemoteDevice &device)
{
if (device.deviceTypeId == THIRD_TV) {
return false;
}
if (device.customData.empty() && device.wifiPort == 0 && device.bleMac.empty()) {
if (device.authVersion == AUTH_VERSION_3) {
CLOGI("Is huawei single device");
return true;
}
if (device.customData.empty()) {
return device.wifiPort != 0 || !device.bleMac.empty();
}
return false;
}
bool ConnectionManager::IsHuaweiDevice(const CastInnerRemoteDevice &device)
{
if (!device.customData.empty()) {
if (device.authVersion == AUTH_VERSION_2) {
CLOGI("Is huawei device");
return true;
}
return false;
@ -1662,10 +1643,12 @@ bool ConnectionManager::IsHuaweiDevice(const CastInnerRemoteDevice &device)
bool ConnectionManager::IsThirdDevice(const CastInnerRemoteDevice &device)
{
if (device.deviceTypeId == THIRD_TV) {
if (device.authVersion == AUTH_VERSION_1) {
CLOGI("Is third device");
return true;
}
return device.bleMac.empty() && device.wifiPort == 0;
return false;
}
bool ConnectionManager::IsBindTarget(std::string deviceId)

View File

@ -138,7 +138,66 @@ void DiscoveryManager::Deinit()
StopDiscovery();
ResetListener();
DeviceManager::GetInstance().UnInitDeviceManager(PKG_NAME);
eventRunner_->Stop();
}
void DiscoveryManager::StartDiscovery(int protocols, std::vector<std::string> drmSchemes)
{
HiSysEventWriteWrap(__func__, {
{"BIZ_SCENE", static_cast<int32_t>(BIZSceneType::DEVICE_DISCOVERY)},
{"BIZ_STATE", static_cast<int32_t>(BIZStateType::BIZ_STATE_BEGIN)},
{"BIZ_STAGE", static_cast<int32_t>(BIZSceneStage::START_DISCOVERY)},
{"STAGE_RES", static_cast<int32_t>(StageResType::STAGE_RES_IDLE)},
{"ERROR_CODE", CAST_RADAR_SUCCESS}}, {
{"TO_CALL_PKG", ""},
{"LOCAL_SESS_NAME", ""},
{"PEER_SESS_NAME", ""},
{"PEER_UDID", ""}});
CLOGI("StartDiscovery in");
protocolType_ = protocols;
drmSchemes_ = drmSchemes;
CastLocalDevice localDevice;
ConnectionManager::GetInstance().GetLocalDeviceInfo(localDevice);
std::lock_guard<std::mutex> lock(mutex_);
std::thread([this]() {
Utils::SetThreadName("DiscoveryEventRunner");
if (eventRunner_ != nullptr) {
eventRunner_->Run();
}
}).detach();
scanCount_ = 0;
for (auto it = remoteDeviceMap_.begin(); it != remoteDeviceMap_.end();) {
std::string deviceId = it->first.deviceId;
CastDeviceDataManager::GetInstance().SetDeviceNotFresh(deviceId);
it++;
}
remoteDeviceMap_.clear();
std::string connectDeviceId = ConnectionManager::GetInstance().GetConnectingDeviceId();
if (!connectDeviceId.empty()) {
auto device = CastDeviceDataManager::GetInstance().GetDeviceByDeviceId(connectDeviceId);
if (device != std::nullopt) {
device->deviceName = "";
remoteDeviceMap_[*device] = scanCount_ + 1;
}
}
uid_ = IPCSkeleton::GetCallingUid();
hasStartDiscovery_ = true;
eventHandler_->RemoveEvent(EVENT_START_DISCOVERY);
eventHandler_->SendEvent(EVENT_START_DISCOVERY);
CLOGI("StartDiscovery out");
}
void DiscoveryManager::StopDiscovery()
{
CLOGI("StopDiscovery in");
hasStartDiscovery_ = false;
SetDeviceNotFresh();
eventHandler_->RemoveAllEvents();
if (eventRunner_ != nullptr) {
eventRunner_->Stop();
}
StopDmDiscovery();
}
void DiscoveryManager::GetAndReportTrustedDevices()
@ -214,40 +273,6 @@ int DiscoveryManager::GetProtocolType() const
return protocolType_;
}
void DiscoveryManager::StartDiscovery()
{
HiSysEventWriteWrap(__func__, {
{"BIZ_SCENE", static_cast<int32_t>(BIZSceneType::DEVICE_DISCOVERY)},
{"BIZ_STATE", static_cast<int32_t>(BIZStateType::BIZ_STATE_BEGIN)},
{"BIZ_STAGE", static_cast<int32_t>(BIZSceneStage::START_DISCOVERY)},
{"STAGE_RES", static_cast<int32_t>(StageResType::STAGE_RES_IDLE)},
{"ERROR_CODE", CAST_RADAR_SUCCESS}}, {
{"TO_CALL_PKG", ""},
{"LOCAL_SESS_NAME", ""},
{"PEER_SESS_NAME", ""},
{"PEER_UDID", ""}});
CLOGI("StartDiscovery in");
scanCount_ = 0;
remoteDeviceMap_.clear();
std::lock_guard<std::mutex> lock(mutex_);
uid_ = IPCSkeleton::GetCallingUid();
eventHandler_->SendEvent(EVENT_START_DISCOVERY);
CLOGI("StartDiscovery out");
}
void DiscoveryManager::StopDiscovery()
{
CLOGI("StopDiscovery in");
SetDeviceNotFresh();
eventHandler_->RemoveAllEvents();
if (eventRunner_ != nullptr) {
eventRunner_->Stop();
}
StopDmDiscovery();
}
void DiscoveryManager::SetListener(std::shared_ptr<IDiscoveryManagerListener> listener)
{
std::lock_guard<std::mutex> lock(mutex_);
@ -346,7 +371,7 @@ void DiscoveryManager::NotifyDeviceIsFound(const CastInnerRemoteDevice &newDevic
}
devices.push_back(newDevice);
std::lock_guard<std::mutex> lock(mutex_);
isNotifyDevice_ = system::GetBoolParameter(NOTIFY_DEVICE_FOUND, false);
auto listener = GetListener();
if (listener == nullptr) {
CLOGE("listener is null");

View File

@ -34,31 +34,33 @@ void RtspChannelManager::ChannelListener::OnDataReceived(const uint8_t *buffer,
{
CLOGD("==============Received data length %{public}u timeCost %{public}ld================", length, timeCost);
if (!((channelManager_->algorithmId_ > 0) &&
!Utils::IsArrayAllZero(channelManager_->sessionKeys_, SESSION_KEY_LENGTH))) {
auto channelManager = channelManager_.lock();
if (channelManager == nullptr) {
CLOGE("channelManager == nullptr");
return;
}
if (!((channelManager->algorithmId_ > 0) &&
!Utils::IsArrayAllZero(channelManager->sessionKeys_, SESSION_KEY_LENGTH))) {
CLOGD("==============Not Authed Recv Msg ================");
CLOGD("Algorithm id %{public}d, length %{public}u.", channelManager_->algorithmId_, length);
channelManager_->OnData(buffer, length);
CLOGD("Algorithm id %{public}d, length %{public}u.", channelManager->algorithmId_, length);
channelManager->OnData(buffer, length);
} else {
unsigned int realPktlen = length - EncryptDecrypt::AES_IV_LEN;
std::unique_ptr<uint8_t[]> decryContent = std::make_unique<uint8_t[]>(realPktlen);
PacketData outputData = { decryContent.get(), 0 };
bool isSucc =
EncryptDecrypt::GetInstance().DecryptData(channelManager_->algorithmId_, channelManager_->sessionKeys_,
channelManager_->sessionKeyLength_, { buffer, static_cast<int>(length) }, outputData);
if (!isSucc) {
CLOGE("ERROR: decode fail or len [%{public}d],expect[%{public}u]", outputData.length, length);
int decryptDataLen = 0;
auto decryContent =
EncryptDecrypt::GetInstance().DecryptData(channelManager->algorithmId_, { channelManager->sessionKeys_,
channelManager->sessionKeyLength_ }, { buffer, static_cast<int>(length) }, decryptDataLen);
if (!decryContent) {
CLOGE("ERROR: decode fail, length[%{public}u]", length);
return;
}
CLOGD("==============Authed Recv Msg ================, decryContent length %{public}u", length);
channelManager_->OnData(decryContent.get(), realPktlen);
channelManager->OnData(decryContent.get(), decryptDataLen);
}
}
RtspChannelManager::RtspChannelManager(RtspListenerInner *listener, ProtocolType protocolType)
RtspChannelManager::RtspChannelManager(std::shared_ptr<RtspListenerInner> listener, ProtocolType protocolType)
: listener_(listener), protocolType_(protocolType)
{
channelListener_ = std::make_shared<ChannelListener>(this);
CLOGI("Out, ProtocolType:%{public}d", protocolType_);
}
@ -67,12 +69,14 @@ RtspChannelManager::~RtspChannelManager()
CLOGI("In.");
memset_s(sessionKeys_, SESSION_KEY_LENGTH, 0, SESSION_KEY_LENGTH);
channelListener_ = nullptr;
StopSafty(false);
ThreadJoin();
}
std::shared_ptr<IChannelListener> RtspChannelManager::GetChannelListener()
{
std::lock_guard<std::mutex> lock(mutex_);
if (channelListener_ == nullptr) {
channelListener_ = std::make_shared<ChannelListener>(shared_from_this());
}
return channelListener_;
}
@ -84,8 +88,13 @@ void RtspChannelManager::AddChannel(std::shared_ptr<Channel> channel, const Cast
channel_ = channel;
bool isSoftbus = channel->GetRequest().linkType == ChannelLinkType::SOFT_BUS;
CLOGD("LinkType %{public}d listener_ is %{public}d", isSoftbus, listener_ == nullptr);
listener_->OnPeerReady(isSoftbus);
auto listener = listener_.lock();
if (!listener) {
CLOGE("listener is nullptr");
return;
}
CLOGD("LinkType %{public}d listener_ is %{public}d", isSoftbus, listener == nullptr);
listener->OnPeerReady(isSoftbus);
}
void RtspChannelManager::RemoveChannel(std::shared_ptr<Channel> channel)
@ -108,49 +117,63 @@ void RtspChannelManager::StopSession()
if (isSessionActive_) {
memset_s(sessionKeys_, SESSION_KEY_LENGTH, 0, SESSION_KEY_LENGTH);
isSessionActive_ = false;
listener_->OnPeerGone();
auto listener = listener_.lock();
if (listener) {
listener->OnPeerGone();
}
}
RemoveMessage(Message(static_cast<int>(RtspState::MSG_NEG_TIMEOUT)));
}
void RtspChannelManager::OnConnected(ChannelLinkType channelLinkType)
{
if (listener_ == nullptr) {
CLOGE("listener is null.");
auto listener = listener_.lock();
if (!listener) {
CLOGE("listener is nullptr");
return;
}
bool isSoftbus = channelLinkType == ChannelLinkType::SOFT_BUS;
CLOGI("IsSoftbus %{public}d.", isSoftbus);
listener_->OnPeerReady(isSoftbus);
listener->OnPeerReady(isSoftbus);
}
void RtspChannelManager::OnData(const uint8_t *data, unsigned int length)
{
std::string str(reinterpret_cast<const char *>(data), length);
CLOGD("In, %{public}s %{public}s", (str.find("RTSP/") == 0) ? "Response...\r\n" : "Request...\r\n", str.c_str());
if (listener_ == nullptr) {
CLOGE("listener is null.");
auto listener = listener_.lock();
if (!listener) {
CLOGE("listener is nullptr");
return;
}
RtspParse msg;
RtspParse::ParseMsg(str, msg);
if (Utils::StartWith(str, "RTSP/")) {
listener_->OnResponse(msg);
listener->OnResponse(msg);
} else {
listener_->OnRequest(msg);
listener->OnRequest(msg);
}
}
void RtspChannelManager::OnError(const std::string &errorCode)
{
CLOGI("In, %{public}s.", errorCode.c_str());
listener_->OnPeerGone();
auto listener = listener_.lock();
if (!listener) {
CLOGE("listener is nullptr");
return;
}
listener->OnPeerGone();
}
void RtspChannelManager::OnClosed(const std::string &errorCode)
{
CLOGI("OnClosed %{public}s.", errorCode.c_str());
listener_->OnPeerGone();
auto listener = listener_.lock();
if (!listener) {
CLOGE("listener is nullptr");
return;
}
listener->OnPeerGone();
}
bool RtspChannelManager::SendData(const std::string &dataFrame)
@ -161,31 +184,23 @@ bool RtspChannelManager::SendData(const std::string &dataFrame)
return false;
}
size_t pktlen = dataFrame.size();
std::unique_ptr<uint8_t[]> encryptContent = std::make_unique<uint8_t[]>(pktlen + EncryptDecrypt::AES_IV_LEN);
PacketData outputData = { encryptContent.get(), 0 };
if (channel->GetRequest().linkType == ChannelLinkType::SOFT_BUS ||
Utils::IsArrayAllZero(sessionKeys_, SESSION_KEY_LENGTH) || algorithmId_ <= 0) {
errno_t ret = memcpy_s(encryptContent.get(), pktlen + EncryptDecrypt::AES_IV_LEN, dataFrame.c_str(), pktlen);
if (ret != EOK) {
CLOGE("ERROR: memory copy error:%{public}d", ret);
return false;
}
outputData.length = static_cast<int>(pktlen);
CLOGD("SendData, get data finish.");
} else {
bool ret = EncryptDecrypt::GetInstance().EncryptData(algorithmId_, sessionKeys_, sessionKeyLength_,
{ reinterpret_cast<const uint8_t *>(dataFrame.c_str()), pktlen }, outputData);
if (!ret || (outputData.length != static_cast<int>(pktlen) + static_cast<int>(EncryptDecrypt::AES_IV_LEN))) {
CLOGE("Encrypt data failed, dataLength: %{public}d, pktlen: %{public}zu", outputData.length, pktlen);
return false;
}
CLOGD("SendData, encrypt data finish.");
return channel->Send(reinterpret_cast<const uint8_t *>(dataFrame.c_str()), pktlen);
}
int encryptedDataLen = 0;
auto encryptedData = EncryptDecrypt::GetInstance().EncryptData(algorithmId_, { sessionKeys_, sessionKeyLength_ },
{ reinterpret_cast<const uint8_t *>(dataFrame.c_str()), pktlen }, encryptedDataLen);
if (!encryptedData) {
CLOGE("Encrypt data failed, pktlen: %{public}zu", pktlen);
return false;
}
CLOGD("SendData, encrypt data finish.");
CLOGD("SendData, outputData.length %{public}d pktlen %{public}zu send buffer %{public}s.", outputData.length,
pktlen, encryptContent.get());
return channel->Send(encryptContent.get(), outputData.length);
CLOGD("SendData, encryptedDataLen %{public}d pktlen %{public}zu send buffer %{public}s.", encryptedDataLen,
pktlen, encryptedData.get());
return channel->Send(encryptedData.get(), encryptedDataLen);
}
bool RtspChannelManager::SendRtspData(const std::string &request)
@ -204,41 +219,11 @@ bool RtspChannelManager::SendRtspData(const std::string &request)
return SendData(request);
}
void RtspChannelManager::CfgNegTimeout(bool isClear)
{
CLOGI("In, %{public}d.", isClear);
if (isClear) {
RemoveMessage(Message(static_cast<int>(RtspState::MSG_NEG_TIMEOUT)));
return;
}
SendCastMessageDelayed(static_cast<int>(RtspState::MSG_NEG_TIMEOUT), KEEP_NEG_TIMEOUT_INTERVAL); // 10s
}
void RtspChannelManager::SetNegAlgorithmId(int algorithmId)
{
algorithmId_ = algorithmId;
CLOGI("SetNegAlgorithmId algorithmId %{public}d.", algorithmId);
}
void RtspChannelManager::HandleMessage(const Message &msg)
{
switch (static_cast<RtspState>(msg.what_)) {
case RtspState::MSG_RTSP_START:
case RtspState::MSG_RTSP_DATA:
case RtspState::MSG_RTSP_CLOSE:
case RtspState::MSG_SEND_KA:
case RtspState::MSG_KA_TIMEOUT:
case RtspState::MSG_NEG_TIMEOUT:
CLOGE("NEG timeout.");
if (listener_ != nullptr) {
listener_->OnPeerGone();
}
break;
default:
break;
}
}
} // namespace CastSessionRtsp
} // namespace CastEngineService
} // namespace CastEngine

View File

@ -21,7 +21,6 @@
#include <mutex>
#include "channel.h"
#include "handler.h"
#include "message.h"
#include "rtsp_listener_inner.h"
#include "cast_engine_common.h"
@ -30,10 +29,9 @@ namespace OHOS {
namespace CastEngine {
namespace CastEngineService {
namespace CastSessionRtsp {
class RtspChannelManager : public Handler, public Message {
class RtspChannelManager : public Message, public std::enable_shared_from_this<RtspChannelManager> {
public:
RtspChannelManager(RtspListenerInner *listener, ProtocolType protocolType);
RtspChannelManager(std::shared_ptr<RtspListenerInner> listener, ProtocolType protocolType);
~RtspChannelManager();
void OnConnected(ChannelLinkType channelLinkType);
@ -49,7 +47,6 @@ public:
std::shared_ptr<IChannelListener> GetChannelListener();
bool SendRtspData(const std::string &request);
void CfgNegTimeout(bool isClear);
void SetNegAlgorithmId(int algorithmId);
private:
@ -66,31 +63,29 @@ private:
class ChannelListener : public IChannelListener {
public:
explicit ChannelListener(RtspChannelManager *channelManager) : channelManager_(channelManager) {}
~ChannelListener()
{
channelManager_ = nullptr;
}
explicit ChannelListener(std::shared_ptr<RtspChannelManager> channelManager) : channelManager_(channelManager)
{}
~ChannelListener() {}
void OnDataReceived(const uint8_t *buffer, unsigned int length, long timeCost) final;
private:
RtspChannelManager *channelManager_;
std::weak_ptr<RtspChannelManager> channelManager_;
};
constexpr static int SESSION_KEY_LENGTH = 16;
bool SendData(const std::string &dataFrame);
void HandleMessage(const Message &msg) override;
uint8_t sessionKeys_[SESSION_KEY_LENGTH] = {0};
uint32_t sessionKeyLength_{ 0 };
RtspListenerInner *listener_;
std::weak_ptr<RtspListenerInner> listener_;
bool isSessionActive_{ false };
std::shared_ptr<Channel> channel_;
std::shared_ptr<ChannelListener> channelListener_;
int algorithmId_{ 0 };
ProtocolType protocolType_;
std::mutex mutex_;
};
} // namespace CastSessionRtsp
} // namespace CastEngineService

View File

@ -41,9 +41,6 @@ std::shared_ptr<IRtspController> IRtspController::GetInstance(std::shared_ptr<IR
RtspController::RtspController(std::shared_ptr<IRtspListener> listener, ProtocolType protocolType, EndType endType)
: protocolType_(protocolType), listener_(listener), endType_(endType)
{
rtspNetManager_ = std::make_unique<RtspChannelManager>(this, protocolType);
ResponseFuncMapInit();
RequestFuncMapInit();
CLOGI("Out, endType %{public}d", endType);
}
@ -309,19 +306,22 @@ bool RtspController::StopEngine()
return true;
}
std::string RtspController::ParseCipherItem(const std::string &item) const
std::set<std::string> RtspController::ParseCipherItem(const std::string &item) const
{
if (item.empty()) {
return "";
return {};
}
std::set<std::string> supportedCipherLists;
std::vector<std::string> cipherLists;
Utils::SplitString(item, cipherLists, ", ");
for (size_t index = 0; index < cipherLists.size(); index++) {
if (Utils::ToLower(cipherLists[index]) == EncryptDecrypt::GetInstance().PC_ENCRYPT_ALG) {
return EncryptDecrypt::GetInstance().PC_ENCRYPT_ALG;
if (Utils::ToLower(cipherLists[index]) == EncryptDecrypt::CIPHER_AES_CTR_128) {
supportedCipherLists.insert(EncryptDecrypt::CIPHER_AES_CTR_128);
} else if (Utils::ToLower(cipherLists[index]) == EncryptDecrypt::CIPHER_AES_GCM_128) {
supportedCipherLists.insert(EncryptDecrypt::CIPHER_AES_GCM_128);
}
}
return "";
return supportedCipherLists;
}
bool RtspController::ProcessAnnounceRequest(RtspParse &request)
@ -349,18 +349,26 @@ bool RtspController::ProcessAnnounceRequest(RtspParse &request)
int version = instance.GetVersion();
CLOGD("AuthNeg: Get algStr is %{public}s version %{public}d", encryptStr.c_str(), version);
std::string sendStr = ParseCipherItem(encryptStr);
if (sendStr.empty() && (listener_ != nullptr)) {
listener_->OnError(ERROR_CODE_DEFAULT);
std::set<std::string> cipherList = ParseCipherItem(encryptStr);
if (cipherList.empty() && (listener_ != nullptr)) {
CLOGE("cipherList is empty");
return false;
}
EncryptionParamInfo encryptionParamInfo{};
encryptionParamInfo.controlChannelAlgId = static_cast<uint32_t>(instance.GetEncryptMatch(sendStr));
encryptionParamInfo.dataChannelAlgId = static_cast<uint32_t>(instance.GetEncryptMatch(sendStr));
encryptionParamInfo.controlChannelAlgId = static_cast<uint32_t>(instance.GetControlEncryptCipher(cipherList));
encryptionParamInfo.dataChannelAlgId = static_cast<uint32_t>(instance.GetMediaEncryptCipher(cipherList));
negotiatedParamInfo_.SetEncryptionParamInfo(encryptionParamInfo);
if (endType_ == EndType::CAST_SOURCE) {
std::string sendStr;
for (auto &cipher : cipherList) {
if (sendStr.empty()) {
sendStr = cipher;
continue;
}
sendStr += ", " + cipher;
}
std::string req = RtspEncap::EncapAnnounce(sendStr, ++currentSeq_, version);
rtspNetManager_->SendRtspData(req);
waitRsp_ = WaitResponse::WAITING_RSP_ANNOUNCE;

View File

@ -76,7 +76,7 @@ private:
bool SendAction(ActionType type);
void ProcessSinkDeviceType(const std::string &content);
bool StopEngine();
std::string ParseCipherItem(const std::string &item) const;
std::set<std::string> ParseCipherItem(const std::string &item) const;
bool ProcessOptionRequest(RtspParse &request);
bool ProcessSetupRequest(RtspParse &request);
bool ProcessGetParameterRequestM3(RtspParse &request);

View File

@ -21,6 +21,7 @@
#include <cstdint>
#include <string>
#include <set>
#include "openssl/hmac.h"
#include "openssl/err.h"
@ -44,14 +45,17 @@ class EncryptDecrypt final {
public:
static EncryptDecrypt &GetInstance();
bool EncryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData, PacketData &outputData);
bool DecryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData, PacketData &outputData);
std::unique_ptr<uint8_t[]> EncryptData(int algCode, ConstPacketData sessionKey,
ConstPacketData inputData, int &outLen);
std::unique_ptr<uint8_t[]> DecryptData(int algCode, ConstPacketData sessionKey,
ConstPacketData inputData, int &outLen);
std::string GetEncryptInfo();
int GetEncryptMatch(const std::string &encyptInfo);
int GetMediaEncryptCipher(const std::set<std::string> &cipherList);
int GetControlEncryptCipher(const std::set<std::string> &cipherList);
int GetVersion();
static const int AES_KEY_LEN_128 = 16;
static const unsigned int AES_IV_LEN = 16;
static const int AES_IV_LEN = 16;
static const int AES_KEY_LEN = 16;
static const int AES_KEY_SIZE = 16;
static const int PC_ENCRYPT_LEN = 64;
@ -61,7 +65,8 @@ public:
static const int CTR_CODE = 1;
static const int GCM_CODE = 2;
static const std::string PC_ENCRYPT_ALG;
static const std::string CIPHER_AES_CTR_128;
static const std::string CIPHER_AES_GCM_128;
private:
enum ErrorCode : int {
@ -97,7 +102,7 @@ private:
SEC_ERR_SETAAD_FAIL,
};
static const int AES_GCM_MAX_IVLEN = 12;
static const int AES_GCM_MIN_IVLEN = 12;
static const int AES_GCM_SIV_TAG_LEN = 16;
static const int UNSIGNED_CHAR_MIN = 0;
static const int UNSIGNED_CHAR_MAX = 255;
@ -111,8 +116,8 @@ private:
ConstPacketData sessionIV);
int AES128GCMCheckEncryPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int AES128GCMCheckDecryptPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo, EVP_CIPHER_CTX *ctx);
int DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo, EVP_CIPHER_CTX *ctx);
int EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int AES128GCMEncry(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int AES128GCMDecrypt(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
};

View File

@ -27,7 +27,8 @@ namespace CastEngine {
namespace CastEngineService {
DEFINE_CAST_ENGINE_LABEL("Cast-EncryptDecrypt");
const std::string EncryptDecrypt::PC_ENCRYPT_ALG = "aes128ctr";
const std::string EncryptDecrypt::CIPHER_AES_CTR_128 = "aes128ctr";
const std::string EncryptDecrypt::CIPHER_AES_GCM_128 = "aes128gcm";
EncryptDecrypt::EncryptDecrypt() {}
@ -159,42 +160,43 @@ int EncryptDecrypt::AES128Decrypt(ConstPacketData inputData, PacketData &outputD
int EncryptDecrypt::AES128GCMCheckEncryPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{
if ((inputData.length < 0) || (encryInfo.aad.length < 0) || (encryInfo.key.length != AES_KEY_LEN_128) ||
(encryInfo.iv.length < AES_GCM_MAX_IVLEN) || (encryInfo.tag.length < AES_GCM_SIV_TAG_LEN)) {
if (encryInfo.key.length != AES_KEY_LEN_128) {
return SEC_ERR_INVALID_KEY_LEN;
}
if ((inputData.data == nullptr) && (inputData.length > 0)) {
if (encryInfo.key.data == nullptr) {
return SEC_ERR_INVALID_KEY;
}
if (inputData.data == nullptr || inputData.length <= 0) {
return SEC_ERR_INVALID_PLAIN;
}
if ((encryInfo.aad.data == nullptr) && (encryInfo.aad.length > 0)) {
return SEC_ERR_INVALID_AAD;
if (encryInfo.iv.length < AES_GCM_MIN_IVLEN) {
return SEC_ERR_INVALID_IV_LEN;
}
if ((encryInfo.key.data == nullptr) || (encryInfo.iv.data == nullptr) || (encryInfo.tag.data == nullptr)) {
if (encryInfo.iv.data == nullptr) {
return SEC_ERR_INVALID_IV;
}
if (outputData.length < inputData.length) {
if (outputData.length < inputData.length || outputData.length <= AES_GCM_SIV_TAG_LEN) {
return SEC_ERR_INVALID_DATA_LEN;
}
if ((outputData.data == nullptr) && (outputData.length > 0)) {
if (outputData.data == nullptr || outputData.length <= 0) {
return SEC_ERR_INVALID_CIPHERTEXT;
}
return 0;
}
int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo,
EVP_CIPHER_CTX *ctx)
int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{
int ret = 0;
int len = 0;
int cipherTextLen = 0;
EVP_CIPHER_CTX *ctx = nullptr;
// Enctypt
do {
/* Create and initialise the context */
ctx = EVP_CIPHER_CTX_new();
if (ctx == nullptr) {
ret = SEC_ERR_CREATECIPHER_FAIL;
break;
return SEC_ERR_CREATECIPHER_FAIL;
}
/* Initialise the encryption operation. */
@ -209,17 +211,14 @@ int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &output
break;
}
/* Initialise key and IV */
if (EVP_EncryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data) != 1) {
ret = SEC_ERR_INVALID_KEY;
if (!EVP_CIPHER_CTX_set_key_length(ctx, encryInfo.key.length)) {
ret = SEC_ERR_INVALID_KEY_LEN;
break;
}
/*
* Provide any AAD data. This can be called zero or more times as required
*/
if (EVP_EncryptUpdate(ctx, nullptr, &len, encryInfo.aad.data, encryInfo.aad.length) != 1) {
ret = SEC_ERR_SETAAD_FAIL;
/* Initialise key and IV */
if (EVP_EncryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data) != 1) {
ret = SEC_ERR_INVALID_KEY;
break;
}
@ -246,73 +245,80 @@ int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &output
}
cipherTextLen += len;
if (cipherTextLen + AES_GCM_SIV_TAG_LEN < outputData.length) {
ret = SEC_ERR_INVALID_DATA_LEN;
break;
}
/* Get the tag */
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, encryInfo.tag.length, encryInfo.tag.data) != 1) {
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, AES_GCM_SIV_TAG_LEN, outputData.data + cipherTextLen) != 1) {
ret = SEC_ERR_GCMGETTAG_FAIL;
break;
}
outputData.length = cipherTextLen;
ret = 0;
} while (0);
outputData.length = cipherTextLen + AES_GCM_SIV_TAG_LEN;
} while (0);
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
return ret;
}
int EncryptDecrypt::AES128GCMEncry(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{
EVP_CIPHER_CTX *ctx = nullptr;
int ret = AES128GCMCheckEncryPara(inputData, outputData, encryInfo);
if (ret != 0) {
return ret;
}
ret = EnctyptProcess(inputData, outputData, encryInfo, ctx);
ret = EnctyptProcess(inputData, outputData, encryInfo);
if (ctx != nullptr) {
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
}
return ret;
}
int EncryptDecrypt::AES128GCMCheckDecryptPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{
if ((inputData.length < 0) || (encryInfo.aad.length < 0) || (encryInfo.key.length != AES_KEY_LEN_128) ||
(encryInfo.iv.length < AES_GCM_MAX_IVLEN) || (encryInfo.tag.length < AES_GCM_SIV_TAG_LEN)) {
if (encryInfo.key.length != AES_KEY_LEN_128) {
return SEC_ERR_INVALID_KEY_LEN;
}
if ((inputData.data == nullptr) && (inputData.length > 0)) {
return SEC_ERR_INVALID_CIPHERTEXT;
if (encryInfo.key.data == nullptr) {
return SEC_ERR_INVALID_KEY;
}
if ((encryInfo.aad.data == nullptr) && (encryInfo.aad.length > 0)) {
return SEC_ERR_INVALID_AAD;
if (encryInfo.iv.length < AES_GCM_MIN_IVLEN) {
return SEC_ERR_INVALID_IV_LEN;
}
if ((encryInfo.key.data == nullptr) || (encryInfo.iv.data == nullptr) || (encryInfo.tag.data == nullptr)) {
if (encryInfo.iv.data == nullptr) {
return SEC_ERR_INVALID_IV;
}
if (encryInfo.tag.data == nullptr) {
return SEC_ERR_NULL_PTR;
}
if (encryInfo.tag.length < AES_GCM_SIV_TAG_LEN) {
return SEC_ERR_INVALID_DATA_LEN;
}
if (inputData.data == nullptr || inputData.length <= 0) {
return SEC_ERR_INVALID_CIPHERTEXT;
}
if (outputData.length < inputData.length) {
return SEC_ERR_INVALID_DATA_LEN;
}
if ((outputData.data == nullptr) && (outputData.length > 0)) {
if (outputData.data == nullptr || outputData.length <= 0) {
return SEC_ERR_INVALID_PLAIN;
}
return 0;
}
int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo,
EVP_CIPHER_CTX *ctx)
int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{
int ret = 0;
int len = 0;
int plainTextLen = 0;
EVP_CIPHER_CTX *ctx = nullptr;
do {
/* Create and initialise the context */
ctx = EVP_CIPHER_CTX_new();
if (ctx == nullptr) {
ret = SEC_ERR_CREATECIPHER_FAIL;
break;
return SEC_ERR_CREATECIPHER_FAIL;
}
/* Initialise the decryption operation. */
@ -327,18 +333,15 @@ int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &output
break;
}
/* Initialise key and IV */
if (!EVP_DecryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data)) {
ret = SEC_ERR_INVALID_KEY;
if (!EVP_CIPHER_CTX_set_key_length(ctx, encryInfo.key.length)) {
CLOGE("key length does not match key algorithm");
ret = SEC_ERR_INVALID_KEY_LEN;
break;
}
/*
* Provide any AAD data. This can be called zero or more times as
* required
*/
if (!EVP_DecryptUpdate(ctx, nullptr, &len, encryInfo.aad.data, encryInfo.aad.length)) {
ret = SEC_ERR_SETAAD_FAIL;
/* Initialise key and IV */
if (!EVP_DecryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data)) {
ret = SEC_ERR_INVALID_KEY;
break;
}
@ -368,143 +371,153 @@ int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &output
}
plainTextLen += len;
outputData.length = plainTextLen;
ret = 0;
} while (0);
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
return ret;
}
int EncryptDecrypt::AES128GCMDecrypt(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{
EVP_CIPHER_CTX *ctx = nullptr;
int ret = AES128GCMCheckDecryptPara(inputData, outputData, encryInfo);
if (ret != 0) {
return ret;
}
ret = DecryptProcess(inputData, outputData, encryInfo, ctx);
if (ctx != nullptr) {
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
}
ret = DecryptProcess(inputData, outputData, encryInfo);
return ret;
}
bool EncryptDecrypt::EncryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData,
PacketData &outputData)
std::unique_ptr<uint8_t[]> EncryptDecrypt::EncryptData(int algCode, ConstPacketData sessionKey,
ConstPacketData inputData, int &outLen)
{
uint8_t sessionIV[AES_KEY_SIZE] = {0};
if (outputData.data == nullptr) {
CLOGE("outputData is null");
return false;
}
if (algCode != CTR_CODE) {
if (algCode != CTR_CODE && algCode != GCM_CODE) {
CLOGE("encrypt not CTR for extension");
return false;
return nullptr;
}
GetAESIv(sessionIV, AES_KEY_SIZE);
uint8_t sessionIV[AES_IV_LEN] = {0};
GetAESIv(sessionIV, AES_IV_LEN);
int encryptDataLen = inputData.length + AES_KEY_SIZE;
int encryptDataLen = inputData.length + AES_IV_LEN;
if (algCode == GCM_CODE) {
encryptDataLen += AES_GCM_SIV_TAG_LEN;
}
std::unique_ptr<uint8_t[]> encryptData = std::make_unique<uint8_t[]>(encryptDataLen);
if (encryptData == nullptr) {
return false;
return nullptr;
}
errno_t ret = memset_s(encryptData.get(), encryptDataLen, 0, encryptDataLen);
if (ret != 0) {
return false;
return nullptr;
}
PacketData output = { encryptData.get(), encryptDataLen };
ConstPacketData sessionKey = { key, keyLen };
PacketData output = { encryptData.get() + AES_IV_LEN, encryptDataLen - AES_IV_LEN };
ConstPacketData iv = { sessionIV, AES_IV_LEN };
ret = AES128Encry(inputData, output, sessionKey, iv);
if (ret != 0 || output.length > inputData.length) {
EncryptInfo encryInfo;
if (algCode == GCM_CODE) {
encryInfo.key = sessionKey;
encryInfo.iv = iv;
ret = AES128GCMEncry(inputData, output, encryInfo);
} else {
ret = AES128Encry(inputData, output, sessionKey, iv);
}
if (ret != 0) {
CLOGE("encrypt error enLen [%u][%u]", ret, output.length);
return false;
return nullptr;
}
ret = memcpy_s(outputData.data, AES_KEY_SIZE, sessionIV, AES_KEY_SIZE);
ret = memcpy_s(encryptData.get(), AES_IV_LEN, sessionIV, AES_IV_LEN);
if (ret != 0) {
return false;
}
ret = memcpy_s(outputData.data + AES_KEY_SIZE, output.length, output.data, output.length);
if (ret != 0) {
return false;
return nullptr;
}
outputData.length = output.length + AES_KEY_SIZE;
return true;
outLen = output.length + AES_IV_LEN;
return encryptData;
}
bool EncryptDecrypt::DecryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData,
PacketData &outputData)
std::unique_ptr<uint8_t[]> EncryptDecrypt::DecryptData(int algCode, ConstPacketData sessionKey,
ConstPacketData inputData, int &outLen)
{
uint8_t sessionIV[AES_KEY_SIZE] = {0};
uint8_t sessionIV[AES_IV_LEN] = {0};
if (algCode != CTR_CODE) {
if (algCode != CTR_CODE && algCode != GCM_CODE) {
CLOGE("decrypt not CTR for extension");
return false;
return nullptr;
}
if (inputData.length <= AES_KEY_SIZE || inputData.data == nullptr || outputData.data == nullptr) {
CLOGE("decrypt para error");
return false;
int minLength = (algCode == GCM_CODE) ? (AES_IV_LEN + AES_GCM_SIV_TAG_LEN) : AES_IV_LEN;
if (inputData.length <= minLength || inputData.data == nullptr) {
CLOGE("decrypt para error, length:%{public}d", inputData.length);
return nullptr;
}
int32_t ret = memcpy_s(sessionIV, AES_KEY_SIZE, inputData.data, AES_KEY_SIZE);
int32_t ret = memcpy_s(sessionIV, AES_IV_LEN, inputData.data, AES_KEY_SIZE);
if (ret != 0) {
CLOGE("memcpy_s failed");
delete[] inputData.data;
return false;
return nullptr;
}
int deLen = inputData.length - AES_KEY_SIZE;
int deLen = inputData.length - AES_IV_LEN;
deLen = (algCode == GCM_CODE) ? (deLen - AES_GCM_SIV_TAG_LEN) : deLen;
std::unique_ptr<uint8_t[]> decryptData = std::make_unique<uint8_t[]>(deLen);
if (decryptData == nullptr) {
CLOGE("create decrypt data memory failed");
return false;
return nullptr;
}
ret = memset_s(decryptData.get(), deLen, 0, deLen);
if (ret != 0) {
CLOGE("memset_s failed");
return false;
return nullptr;
}
PacketData output = { decryptData.get(), deLen };
ConstPacketData sessionKey = { key, keyLen };
ConstPacketData iv = { sessionIV, AES_IV_LEN };
ConstPacketData input = { inputData.data + AES_KEY_SIZE, deLen };
ret = AES128Decrypt(input, output, sessionKey, iv);
if (algCode == GCM_CODE) {
EncryptInfo encryInfo;
encryInfo.key = sessionKey;
encryInfo.iv = iv;
encryInfo.tag = { const_cast<uint8_t *>(input.data + input.length), AES_GCM_SIV_TAG_LEN };
ret = AES128GCMDecrypt(input, output, encryInfo);
} else {
ret = AES128Decrypt(input, output, sessionKey, iv);
}
if (ret != 0 || output.length != deLen) {
CLOGE("decrypt error and ret[%{public}d] Len[%u] delen[%{public}d]", ret, output.length, deLen);
return false;
return nullptr;
}
ret = memcpy_s(outputData.data, output.length, output.data, output.length);
if (ret != 0) {
return false;
}
outputData.length = output.length;
outLen = output.length;
return true;
return decryptData;
}
std::string EncryptDecrypt::GetEncryptInfo()
{
return PC_ENCRYPT_ALG;
return CIPHER_AES_CTR_128;
}
int EncryptDecrypt::GetEncryptMatch(const std::string &encyptInfo)
int EncryptDecrypt::GetMediaEncryptCipher(const std::set<std::string> &cipherList)
{
if (encyptInfo.size() >= PC_ENCRYPT_LEN) {
return INVALID_CODE;
}
if (encyptInfo == PC_ENCRYPT_ALG) {
if (cipherList.count(CIPHER_AES_CTR_128) != 0) {
return CTR_CODE;
}
CLOGE("not support the cipherlist");
return INVALID_CODE;
}
int EncryptDecrypt::GetControlEncryptCipher(const std::set<std::string> &cipherList)
{
// GCM is preferred, followed by CTR
if (cipherList.count(CIPHER_AES_GCM_128) != 0) {
return GCM_CODE;
}
if (cipherList.count(CIPHER_AES_CTR_128) != 0) {
return CTR_CODE;
}
CLOGE("not support the cipherlist");
return INVALID_CODE;
}
int EncryptDecrypt::GetVersion()
{
return VERSION;