!106 更新发现连接方法及AES加解密

Merge pull request !106 from LongestDistance/master
This commit is contained in:
openharmony_ci 2024-11-15 16:19:05 +00:00 committed by Gitee
commit a3d15bad46
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 539 additions and 393 deletions

View File

@ -0,0 +1,135 @@
/*
* Copyright (C) 2023-2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* Description: implement the cast source connect
* Author: chenggong
* Create: 2023-10-23
*/
#ifndef CAST_META_NODE_CONSTANT_H
#define CAST_META_NODE_CONSTANT_H
#include <string>
namespace OHOS {
namespace CastEngine {
namespace CastEngineService {
const std::string KEY_BIND_TARGET_ACTION = "action";
constexpr int ACTION_CONNECT_DEVICE = 0;
constexpr int ACTION_QUERY_P2P_IP = 1;
constexpr int ACTION_SEND_MESSAGE = 2;
constexpr int ACTION_FINISH_BINDING_STATUS = 3;
const std::string KEY_LOCAL_NETWORK_ID = "localNetworkId";
const std::string KEY_REMOTE_NETWORK_ID = "remoteNetworkId";
const std::string KEY_LOCAL_WIFI_IP = "localWifiIp";
const std::string KEY_REMOTE_WIFI_IP = "remoteWifiIp";
const std::string KEY_LOCAL_P2P_IP = "localP2PIp";
const std::string KEY_REMOTE_P2P_IP = "remoteP2PIp";
const std::string NETWORK_ID = "networkId";
const std::string KEY_TRANSFER_MODE = "transferMode";
const std::string KEY_SESSION_ID = "sessionId";
const std::string PORT_KEY = "port";
const std::string SOURCE_IP_KEY = "sourceIp";
const std::string SINK_IP_KEY = "sinkIp";
const std::string TYPE_SESSION_KEY = "sessionKey";
const std::string KEY_CONNECT_TYPE = "connectType";
const std::string CONNECT_TYPE_WIFI = "lan"; // Do not modify the string
const std::string CONNECT_TYPE_P2P = "p2p"; // Do not modify the string
const std::string KEY_PROJECTION_MODE = "projectionMode";
const std::string PROJECTION_MODE_MIRROR = "MIRROR";
const std::string PROJECTION_MODE_STREAM = "MEDIA_RESOURCE";
const std::string KEY_UNBIND_TARGET_CAUSE = "unbindCause";
constexpr int CAUSE_AUTH_SUCCESS = 0;
constexpr int CAUSE_AUTH_FAILED = 1;
constexpr int CAUSE_DISCONNECT = 2;
inline constexpr int OK = 0;
inline constexpr int INVALID_VALUE = -1;
constexpr int AUTH_MODE_GENERIC = 1;
constexpr int AUTH_MODE_PWD = 2;
constexpr int CAST_SESSION_KEY_LENGTH = 16;
const std::string TIME_RECORD = "TimeRecord";
const std::string TOTAL_AUTH_TIME = "totalAuthTime";
const std::string TIME_RECORD_STRING = "timeRecordString";
const std::string VERSION_OH = "OH1.0";
const std::string PACKAGE_NAME_KEY = "packageName";
const std::string VERSION_NAME_KEY = "castpluskitVersionName";
const std::string PRODUCT_NAME_KEY = "productName";
const std::string CAST_PLUS_APP_VERSION_KEY = "castplusAppVersion";
const std::string PRODUCT_ID_KEY = "productId";
const std::string IS_VIRGIN_SINK = "isVirginSink";
const std::string ALLOWED_ALWAYS = "allowedAlways";
const std::string DEVICE_CAST_SOURCE = "deviceCastSource";
const std::string AUTH_MODE_KEY = "authMode";
const std::string VERSION_KEY = "version";
const std::string OPERATION_TYPE_KEY = "operType";
const std::string SEQUENCE_NUMBER = "sequenceNumber";
const std::string DATA_KEY = "data";
const std::string DEVICE_SALT_KEY = "salt";
const std::string CONNECT_TYPE_PRIORITY_KEY = "ConnectTypePriority";
const std::string TRIGGER_TYPE_KEY = "triggerType";
const std::string HANDSHAKE_RESULT_KEY = "handshakeResult";
const std::string DEVICE_IP_KEY = "deviceIp";
const std::string CURRENT_TIME_KEY = "currentTime";
const std::string KIT_AUTH_STATUS_KEY = "kitAuthStatus";
const std::string KIT_AUTH_POLICY_KEY = "kitAuthPolicy";
const std::string IS_GENERIC_TRUSTED_KEY = "isGenericTrusted";
const std::string IS_PWD_TRUSTED_KEY = "isPwdTrusted";
const std::string IS_SAME_ACCOUNT = "isSameAccount";
const std::string IS_SAME_FAMILY_GROUP = "isSameFamilyGroup";
const std::string IS_CONFIRMED_KEY = "isConfirmed";
const std::string DEVICE_ONLINE_INFO_KEY = "onlineInfo";
const std::string COORDINATION_CAPABILITY_KEY = "hasCoordinationCapability";
const std::string KIT_AUTH_CHOICE_KEY = "kitAuthChoice";
const std::string PHONE_MODEL_VERSION_KEY = "phoneModelVersion";
const std::string MODEL_NAME_KEY = "modelName";
const std::string CONSULT_RESULT = "consultResult";
const std::string SUB_REASON = "subReason";
const std::string SN_KEY = "serial_number";
const std::string CAPABILITY_INFO_KEY = "capabilityInfo";
// consult key
const std::string DEVICE_ID_KEY = "deviceId";
const std::string DEVICE_NAME_KEY = "deviceName";
const std::string REMOTE_DEVICE_NAME_KEY = "remoteDeviceName";
const std::string PROTOCOL_TYPE_KEY = "protocolType";
const std::string ACCOUNT_ID_KEY = "accountId";
const std::string USER_ID_KEY = "userId";
const std::string KIT_VER_PREFIX = "castpluskit";
// Prefix of third party castplus kit
const std::string KIT_VER_1X = "castpluskit 1";
// Prefix of HMOS2.x castplus kit
const std::string KIT_VER_2X = "castpluskit 2";
enum OperationStep : int32_t {
OPERATION_HANDSHAKE = 1,
OPERATION_AUTHENTICATE = 2,
OPERATION_CONSULT = 3,
};
enum DmAuthStatusExt : int32_t {
HANDSHAKE_RESPONSE = 1000,
};
} // namespace CastEngineService
} // namespace CastEngine
} // namespace OHOS
#endif

View File

@ -96,7 +96,6 @@ private:
ServiceStatus serviceStatus_{ ServiceStatus::DISCONNECTED }; ServiceStatus serviceStatus_{ ServiceStatus::DISCONNECTED };
int sessionCapacity_{ 0 }; int sessionCapacity_{ 0 };
std::map<int32_t, sptr<ICastSessionImpl>> sessionMap_; std::map<int32_t, sptr<ICastSessionImpl>> sessionMap_;
std::atomic<int> sessionIndex_{ 0 };
std::unordered_map<pid_t, sptr<IRemoteObject::DeathRecipient>> deathRecipientMap_; std::unordered_map<pid_t, sptr<IRemoteObject::DeathRecipient>> deathRecipientMap_;
std::atomic<bool> hasServer_{ false }; std::atomic<bool> hasServer_{ false };
std::atomic<bool> isUnloading_{ false }; std::atomic<bool> isUnloading_{ false };

View File

@ -560,13 +560,9 @@ int32_t CastSessionManagerService::SetSinkSessionCapacity(int sessionCapacity)
int32_t CastSessionManagerService::StartDiscovery(int protocols, std::vector<std::string> drmSchemes) int32_t CastSessionManagerService::StartDiscovery(int protocols, std::vector<std::string> drmSchemes)
{ {
static_cast<void>(protocols);
CLOGI("StartDiscovery in, protocolType = %{public}d, drm shcheme size = %{public}zu", protocols, drmSchemes.size()); CLOGI("StartDiscovery in, protocolType = %{public}d, drm shcheme size = %{public}zu", protocols, drmSchemes.size());
SharedRLock lock(mutex_);
if (!Permission::CheckPidPermission()) { DiscoveryManager::GetInstance().StartDiscovery(protocols, drmSchemes);
return ERR_NO_PERMISSION;
}
DiscoveryManager::GetInstance().StartDiscovery();
return CAST_ENGINE_SUCCESS; return CAST_ENGINE_SUCCESS;
} }

View File

@ -95,7 +95,7 @@ public:
void Init(std::shared_ptr<IDiscoveryManagerListener> listener); void Init(std::shared_ptr<IDiscoveryManagerListener> listener);
void Deinit(); void Deinit();
void StartDiscovery(); void StartDiscovery(int protocols, std::vector<std::string> drmSchemes);
void StopDiscovery(); void StopDiscovery();
bool StartAdvertise(); bool StartAdvertise();
@ -137,13 +137,15 @@ private:
std::mutex mutex_; std::mutex mutex_;
int32_t uid_{ 0 }; int32_t uid_{ 0 };
int protocolType_ = 0; bool isNotifyDevice_{ false };
int protocolType_;
std::vector<std::string> drmSchemes_; std::vector<std::string> drmSchemes_;
std::shared_ptr<IDiscoveryManagerListener> listener_; std::shared_ptr<IDiscoveryManagerListener> listener_;
std::shared_ptr<EventRunner> eventRunner_; std::shared_ptr<EventRunner> eventRunner_;
std::shared_ptr<DiscoveryEventHandler> eventHandler_; std::shared_ptr<DiscoveryEventHandler> eventHandler_;
std::unordered_map<CastInnerRemoteDevice, int> remoteDeviceMap_; std::unordered_map<CastInnerRemoteDevice, int> remoteDeviceMap_;
int32_t scanCount_; int32_t scanCount_;
bool hasStartDiscovery_ = false;
}; };
} // namespace CastEngineService } // namespace CastEngineService
} // namespace CastEngine } // namespace CastEngine

View File

@ -321,16 +321,16 @@ RemoteDeviceState CastDeviceDataManager::GetDeviceState(const std::string &devic
return GetDeviceStateLocked(deviceId); return GetDeviceStateLocked(deviceId);
} }
bool CastDeviceDataManager::IsDeviceConnecting(const std::string &deviceId)
{
return GetDeviceState(deviceId) == RemoteDeviceState::CONNECTING;
}
bool CastDeviceDataManager::IsDeviceConnected(const std::string &deviceId) bool CastDeviceDataManager::IsDeviceConnected(const std::string &deviceId)
{ {
return GetDeviceState(deviceId) == RemoteDeviceState::CONNECTED; return GetDeviceState(deviceId) == RemoteDeviceState::CONNECTED;
} }
bool CastDeviceDataManager::IsDeviceConnecting(const std::string &deviceId)
{
return GetDeviceState(deviceId) == RemoteDeviceState::CONNECTING;
}
bool CastDeviceDataManager::IsDeviceUsed(const std::string &deviceId) bool CastDeviceDataManager::IsDeviceUsed(const std::string &deviceId)
{ {
auto state = GetDeviceState(deviceId); auto state = GetDeviceState(deviceId);

View File

@ -106,8 +106,7 @@ SubDeviceType ConvertSubDeviceType(uint16_t deviceTypeId)
const std::string AUTH_VERSION_KEY = "authVersion"; const std::string AUTH_VERSION_KEY = "authVersion";
const std::string AUTH_VERSION_1 = "1.0"; const std::string AUTH_VERSION_1 = "1.0";
const std::string AUTH_VERSION_2 = "2.0"; const std::string AUTH_VERSION_2 = "2.0";
const std::string AUTH_VERSION_3 = "DM";
constexpr int THIRD_TV = 0x2E;
const std::string KEY_BIND_TARGET_ACTION = "action"; const std::string KEY_BIND_TARGET_ACTION = "action";
constexpr int ACTION_CONNECT_DEVICE = 0; constexpr int ACTION_CONNECT_DEVICE = 0;
@ -118,12 +117,6 @@ const std::string KEY_LOCAL_P2P_IP = "localP2PIp";
const std::string KEY_REMOTE_P2P_IP = "remoteP2PIp"; const std::string KEY_REMOTE_P2P_IP = "remoteP2PIp";
const std::string NETWORK_ID = "networkId"; const std::string NETWORK_ID = "networkId";
constexpr static int SECOND_BYTE_OFFSET = 8;
constexpr static int THIRD_BYTE_OFFSET = 16;
constexpr static int FOURTH_BYTE_OFFSET = 24;
constexpr static int INT_FOUR = 4;
void DeviceDiscoveryWriteWrap(const std::string& funcName, const std::string& puid) void DeviceDiscoveryWriteWrap(const std::string& funcName, const std::string& puid)
{ {
HiSysEventWriteWrap(funcName, { HiSysEventWriteWrap(funcName, {
@ -151,20 +144,6 @@ void EstablishConsultWriteWrap(const std::string& funcName, int sceneType, const
{"PEER_UDID", puid}}); {"PEER_UDID", puid}});
} }
void DeviceAuthWriteWrap(const std::string& funcName, int sceneType, const std::string& puid)
{
HiSysEventWriteWrap(funcName, {
{"BIZ_SCENE", sceneType},
{"BIZ_STATE", static_cast<int32_t>(BIZStateType::BIZ_STATE_BEGIN)},
{"BIZ_STAGE", static_cast<int32_t>(BIZSceneStage::DEVICE_AUTHENTICATION)},
{"STAGE_RES", static_cast<int32_t>(StageResType::STAGE_RES_SUCCESS)},
{"ERROR_CODE", CAST_RADAR_SUCCESS}}, {
{"TO_CALL_PKG", DEVICE_MANAGER_NAME},
{"LOCAL_SESS_NAME", ""},
{"PEER_SESS_NAME", ""},
{"PEER_UDID", puid}});
}
} // namespace } // namespace
namespace SoftBus { namespace SoftBus {
@ -289,6 +268,11 @@ int BindSocket(int32_t socketId, const ProtocolType &protocolType, bool isSingle
return SOFTBUS_OK; return SOFTBUS_OK;
} }
} // namespace SoftBus } // namespace SoftBus
constexpr static int SECOND_BYTE_OFFSET = 8;
constexpr static int THIRD_BYTE_OFFSET = 16;
constexpr static int FOURTH_BYTE_OFFSET = 24;
constexpr static int INT_FOUR = 4;
/* /*
* auth success * auth success
@ -312,7 +296,7 @@ const std::string USER_ID_KEY = "userId";
* User's unusual action or other event scenarios could cause changing of STATE or RESULT which delivered * User's unusual action or other event scenarios could cause changing of STATE or RESULT which delivered
* by DM. * by DM.
*/ */
const std::map<int32_t, int32_t> CastBindTargetCallback::RESULT_REASON_MAP = { const std::map<int, int32_t> CastBindTargetCallback::RESULT_REASON_MAP = {
// SINK peer click distrust button during 3-state authentication. // SINK peer click distrust button during 3-state authentication.
{ ERR_DM_AUTH_PEER_REJECT, REASON_DISTRUST_BY_SINK }, { ERR_DM_AUTH_PEER_REJECT, REASON_DISTRUST_BY_SINK },
// SINK peer click cancel button during pin code inputting. // SINK peer click cancel button during pin code inputting.
@ -325,7 +309,7 @@ const std::map<int32_t, int32_t> CastBindTargetCallback::RESULT_REASON_MAP = {
{ STOP_BIND, REASON_STOP_BIND_BY_SOURCE } { STOP_BIND, REASON_STOP_BIND_BY_SOURCE }
}; };
const std::map<int32_t, int32_t> CastBindTargetCallback::STATUS_REASON_MAP = { const std::map<int, int32_t> CastBindTargetCallback::STATUS_REASON_MAP = {
// DEFAULT event // DEFAULT event
{ DmAuthStatus::STATUS_DM_AUTH_DEFAULT, REASON_DEFAULT }, { DmAuthStatus::STATUS_DM_AUTH_DEFAULT, REASON_DEFAULT },
// Sink peer click trust during 3-state authentication. // Sink peer click trust during 3-state authentication.
@ -494,29 +478,30 @@ void ConnectionManager::OnConsultDataReceived(int transportId, const void *data,
CLOGE("Failed to get DmDeviceInfo"); CLOGE("Failed to get DmDeviceInfo");
return; return;
} }
int castSessionId = INVALID_ID;
constexpr int32_t sleepTimeMs = 50; constexpr int32_t sleepTimeMs = 50;
constexpr int32_t retryTimes = 20; constexpr int32_t retryTimes = 20;
int32_t retryTime = 0; int castSessionId = GetCastSessionId(transportId);
for (int32_t retryTime = 1; castSessionId == INVALID_ID && retryTime < retryTimes; retryTime++) {
while (castSessionId == INVALID_ID) { CLOGD("Retry for the %d(th) time after sleeping %dms", retryTime, sleepTimeMs);
if (castSessionId != INVALID_ID || retryTime > retryTimes) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMs)); std::this_thread::sleep_for(std::chrono::milliseconds(sleepTimeMs));
castSessionId = GetCastSessionId(transportId); castSessionId = GetCastSessionId(transportId);
retryTime++;
} }
if (castSessionId == INVALID_ID) { if (castSessionId == INVALID_ID) {
CLOGE("session id invalid"); CLOGE("Invalid CastSessionId!");
return; return;
} }
CLOGI("protocolType is %d", device->protocolType); CLOGI("protocolType is %d", device->protocolType);
if (device->protocolType == ProtocolType::CAST_PLUS_STREAM) { if (device->protocolType == ProtocolType::CAST_PLUS_STREAM) {
SetSessionProtocolType(castSessionId, device->protocolType); SetSessionProtocolType(castSessionId, device->protocolType);
} }
if (!listener_) {
CLOGE("Detect absence of listener_.");
return;
}
listener_->ReportSessionCreate(castSessionId);
device->localCastSessionId = castSessionId;
if (!CastDeviceDataManager::GetInstance().AddDevice(*device, dmDevice)) { if (!CastDeviceDataManager::GetInstance().AddDevice(*device, dmDevice)) {
return; return;
} }
@ -535,13 +520,12 @@ void ConnectionManager::OnConsultDataReceived(int transportId, const void *data,
int ConnectionManager::GetCastSessionId(int transportId) int ConnectionManager::GetCastSessionId(int transportId)
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
for (const auto &element : transIdToCastSessionIdMap_) { if (transIdToCastSessionIdMap_.count(transportId) == 1) {
if (element.first == transportId) { return transIdToCastSessionIdMap_[transportId];
return element.second; } else {
} CLOGE("Invalid transport id:%{public}d", transportId);
return INVALID_ID;
} }
CLOGE("Invalid transport id:%{public}d", transportId);
return INVALID_ID;
} }
bool ConnectionManager::OnConsultSessionOpened(int transportId, bool isSource) bool ConnectionManager::OnConsultSessionOpened(int transportId, bool isSource)
@ -655,12 +639,10 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
DeviceDiscoveryWriteWrap(__func__, GetAnonymousDeviceID(dev.deviceId)); DeviceDiscoveryWriteWrap(__func__, GetAnonymousDeviceID(dev.deviceId));
auto &deviceId = dev.deviceId; auto &deviceId = dev.deviceId;
CLOGI("ConnectDevice in, %s", deviceId.c_str()); CLOGI("deviceId %{public}s, protocolType %{public}d, capabilityInfo %{public}d, wifiIp %{public}s, "
"bleMac %{public}s, isLeagacy %{public}d, isFresh wifi %{public}d, ble %{public}d",
if (CastDeviceDataManager::GetInstance().IsDeviceUsed(deviceId)) { Utils::Mask(deviceId).c_str(), protocolType, dev.capabilityInfo, Utils::Mask(dev.wifiIp).c_str(),
CLOGD("Device: %s is used.", deviceId.c_str()); Utils::Mask(dev.bleMac).c_str(), dev.isLeagacy, dev.isWifiFresh, dev.isBleFresh);
return true;
}
if (!UpdateDeviceState(deviceId, RemoteDeviceState::CONNECTING)) { if (!UpdateDeviceState(deviceId, RemoteDeviceState::CONNECTING)) {
CLOGE("Device(%s) is missing", deviceId.c_str()); CLOGE("Device(%s) is missing", deviceId.c_str());
@ -672,7 +654,7 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
if (IsNeedDiscoveryDevice(dev)) { if (IsNeedDiscoveryDevice(dev)) {
CLOGI("need discovery device"); CLOGI("need discovery device");
DiscoveryManager::GetInstance().StartDiscovery(); DiscoveryManager::GetInstance().StartDiscovery(static_cast<int>(protocolType), {});
std::thread([this, dev]() { std::thread([this, dev]() {
Utils::SetThreadName("ConnectTargetDevice"); Utils::SetThreadName("ConnectTargetDevice");
WaitAndConnectTargetDevice(dev); WaitAndConnectTargetDevice(dev);
@ -684,7 +666,7 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
std::string networkId; std::string networkId;
if (IsDeviceTrusted(dev.deviceId, networkId) && IsSingle(dev) && SourceCheckConnectAccess(networkId)) { if (IsDeviceTrusted(dev.deviceId, networkId) && IsSingle(dev) && SourceCheckConnectAccess(networkId)) {
DeviceAuthWriteWrap(__func__, GetBIZSceneType(GetProtocolType()), GetAnonymousDeviceID(dev.deviceId)); NotifyListenerToLoadSinkSA(networkId);
if (!CastDeviceDataManager::GetInstance().SetDeviceNetworkId(deviceId, networkId) || if (!CastDeviceDataManager::GetInstance().SetDeviceNetworkId(deviceId, networkId) ||
!OpenConsultSession(dev)) { !OpenConsultSession(dev)) {
(void)UpdateDeviceState(deviceId, RemoteDeviceState::FOUND); (void)UpdateDeviceState(deviceId, RemoteDeviceState::FOUND);
@ -697,12 +679,13 @@ bool ConnectionManager::ConnectDevice(const CastInnerRemoteDevice &dev, const Pr
(void)UpdateDeviceState(deviceId, RemoteDeviceState::FOUND); (void)UpdateDeviceState(deviceId, RemoteDeviceState::FOUND);
return false; return false;
} }
std::unique_lock<std::mutex> lock(mutex_);
if (isBindTargetMap_.find(deviceId) != isBindTargetMap_.end()) { if (isBindTargetMap_.find(deviceId) != isBindTargetMap_.end()) {
isBindTargetMap_[deviceId] = true; isBindTargetMap_[deviceId] = true;
} else { } else {
isBindTargetMap_.insert({ deviceId, true }); isBindTargetMap_.insert({ deviceId, true });
} }
CLOGI("ConnectDevice out, %s", deviceId.c_str()); CLOGI("ConnectDevice out, %{public}s", Utils::Mask(deviceId).c_str());
return true; return true;
} }
@ -710,6 +693,8 @@ void ConnectionManager::DisconnectDevice(const std::string &deviceId)
{ {
CLOGI("DisconnectDevice in, deviceId %{public}s", Utils::Mask(deviceId).c_str()); CLOGI("DisconnectDevice in, deviceId %{public}s", Utils::Mask(deviceId).c_str());
std::unique_lock<std::mutex> lock(mutex_);
connectingDeviceId_ = "";
DiscoveryManager::GetInstance().StopDiscovery(); DiscoveryManager::GetInstance().StopDiscovery();
if (!CastDeviceDataManager::GetInstance().IsDeviceUsed(deviceId)) { if (!CastDeviceDataManager::GetInstance().IsDeviceUsed(deviceId)) {
CLOGE("Device(%s) is not used, remove it", deviceId.c_str()); CLOGE("Device(%s) is not used, remove it", deviceId.c_str());
@ -717,7 +702,11 @@ void ConnectionManager::DisconnectDevice(const std::string &deviceId)
return; return;
} }
protocolType_ = ProtocolType::CAST_PLUS_MIRROR;
lock.unlock();
UpdateDeviceState(deviceId, RemoteDeviceState::FOUND);
DestroyConsulationSession(deviceId); DestroyConsulationSession(deviceId);
CastDeviceDataManager::GetInstance().GetDeviceByDeviceId(deviceId);
auto isActiveAuth = CastDeviceDataManager::GetInstance().GetDeviceIsActiveAuth(deviceId); auto isActiveAuth = CastDeviceDataManager::GetInstance().GetDeviceIsActiveAuth(deviceId);
if (isActiveAuth == std::nullopt) { if (isActiveAuth == std::nullopt) {
return; return;
@ -778,12 +767,17 @@ int32_t ConnectionManager::GetLocalDeviceInfo(CastLocalDevice &device)
void ConnectionManager::NotifySessionIsReady(int transportId) void ConnectionManager::NotifySessionIsReady(int transportId)
{ {
if (!listener_) {
CLOGE("Detect absence of listener_.");
return;
}
int castSessionId = listener_->NotifySessionIsReady(); int castSessionId = listener_->NotifySessionIsReady();
if (castSessionId == INVALID_ID) { if (castSessionId == INVALID_ID) {
CLOGE("sessionId is invalid"); CLOGE("sessionId is invalid");
return; return;
} }
CLOGD("Update cast session id map: %d: %d", transportId, castSessionId);
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
transIdToCastSessionIdMap_.insert({ transportId, castSessionId }); transIdToCastSessionIdMap_.insert({ transportId, castSessionId });
} }
@ -791,11 +785,11 @@ void ConnectionManager::NotifySessionIsReady(int transportId)
void ConnectionManager::NotifyDeviceIsOffline(const std::string &deviceId) void ConnectionManager::NotifyDeviceIsOffline(const std::string &deviceId)
{ {
CLOGI("NotifyDeviceIsOffline in"); CLOGI("NotifyDeviceIsOffline in");
std::lock_guard<std::mutex> lock(mutex_); auto listener = GetListener();
if (!listener_) { if (!listener) {
return; return;
} }
listener_->NotifyDeviceIsOffline(deviceId); listener->NotifyDeviceIsOffline(deviceId);
} }
bool ConnectionManager::NotifyConnectStage(const CastInnerRemoteDevice &device, int result, int32_t reasonCode) bool ConnectionManager::NotifyConnectStage(const CastInnerRemoteDevice &device, int result, int32_t reasonCode)
@ -972,26 +966,12 @@ bool ConnectionManager::BindTarget(const CastInnerRemoteDevice &dev)
BuildBindParam(dev, bindParam); BuildBindParam(dev, bindParam);
int ret = DeviceManager::GetInstance().BindTarget(PKG_NAME, targetId, bindParam, int ret = DeviceManager::GetInstance().BindTarget(PKG_NAME, targetId, bindParam,
std::make_shared<CastBindTargetCallback>()); std::make_shared<CastBindTargetCallback>());
if (ret == ERR_DM_AUTH_BUSINESS_BUSY) {
CLOGE("bind fail, target is binding %d", ret);
auto networkId = CastDeviceDataManager::GetInstance().GetDeviceNetworkId(dev.deviceId);
PeerTargetId targetId = {
.deviceId = *networkId,
};
std::map<std::string, std::string> unbindParam{};
if (!CastDeviceDataManager::GetInstance().IsDoubleFrameDevice(dev.deviceId)) {
DeviceManager::GetInstance().UnbindTarget(PKG_NAME, targetId, unbindParam, nullptr);
} else {
unbindParam.insert(
std::pair<std::string, std::string>(PARAM_KEY_META_TYPE, std::to_string(5)));
DeviceManager::GetInstance().UnbindTarget(PKG_NAME, targetId, unbindParam, nullptr);
}
return false;
}
if (ret != DM_OK) { if (ret != DM_OK) {
CLOGE("ConnectDevice BindTarget fail, ret = %{public}d)", ret); CLOGE("ConnectDevice BindTarget fail, ret = %{public}d)", ret);
CastEngineDfx::WriteErrorEvent(AUTHENTICATE_DEVICE_FAIL); if (ret == ERR_DM_AUTH_BUSINESS_BUSY) {
DeviceManager::GetInstance().UnbindTarget(
PKG_NAME, targetId, bindParam, std::make_shared<CastUnBindTargetCallback>());
}
return false; return false;
} }
@ -1132,12 +1112,15 @@ void ConnectionManager::EncryptPort(int port, const uint8_t *sessionKey, json &b
int portArraySize = 4; int portArraySize = 4;
ConstPacketData inputData = { portArray.get(), portArraySize }; ConstPacketData inputData = { portArray.get(), portArraySize };
uint8_t encryptedPort[portArraySize + EncryptDecrypt::AES_IV_LEN]; int encryptedDataLen = 0;
PacketData outputData = { encryptedPort, 0 }; auto encryptedData = EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, { sessionKey,
EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, sessionKey, SESSION_KEY_LENGTH, inputData, SESSION_KEY_LENGTH }, inputData, encryptedDataLen);
outputData); if (!encryptedData) {
CLOGD("encrypt result is %d ", outputData.length); CLOGE("encrypt error");
std::string encryptedPortLatin1(reinterpret_cast<const char *>(outputData.data), outputData.length); return;
}
CLOGD("encrypt result is %d ", encryptedDataLen);
std::string encryptedPortLatin1(reinterpret_cast<const char *>(encryptedData.get()), encryptedDataLen);
std::string encryptedPortUtf8 = convLatin1ToUTF8(encryptedPortLatin1); std::string encryptedPortUtf8 = convLatin1ToUTF8(encryptedPortLatin1);
body[PORT_KEY] = encryptedPortUtf8; body[PORT_KEY] = encryptedPortUtf8;
} }
@ -1148,11 +1131,15 @@ void ConnectionManager::EncryptIp(const std::string &ip, const std::string &key,
return; return;
} }
ConstPacketData inputData = { reinterpret_cast<const uint8_t *>(ip.c_str()), ip.size() }; ConstPacketData inputData = { reinterpret_cast<const uint8_t *>(ip.c_str()), ip.size() };
uint8_t encrypted[ip.size() + EncryptDecrypt::AES_IV_LEN]; int encryptedDataLen = 0;
PacketData outputData = { encrypted, 0 }; auto encryptedData = EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, { sessionKey,
EncryptDecrypt::GetInstance().EncryptData(EncryptDecrypt::CTR_CODE, sessionKey, SESSION_KEY_LENGTH, inputData, SESSION_KEY_LENGTH }, inputData, encryptedDataLen);
outputData); if (!encryptedData) {
for (int i = 0; i < outputData.length; i++) { CLOGE("encrypt error");
return;
}
uint8_t *encrypted = encryptedData.get();
for (int i = 0; i < encryptedDataLen; i++) {
body[key].push_back(encrypted[i]); body[key].push_back(encrypted[i]);
} }
CLOGI("encrypt %s finish", key.c_str()); CLOGI("encrypt %s finish", key.c_str());
@ -1160,12 +1147,13 @@ void ConnectionManager::EncryptIp(const std::string &ip, const std::string &key,
std::unique_ptr<uint8_t[]> ConnectionManager::intToByteArray(int32_t num) std::unique_ptr<uint8_t[]> ConnectionManager::intToByteArray(int32_t num)
{ {
unsigned int number = static_cast<unsigned int>(num);
std::unique_ptr<uint8_t[]> result = std::make_unique<uint8_t[]>(INT_FOUR); std::unique_ptr<uint8_t[]> result = std::make_unique<uint8_t[]>(INT_FOUR);
int i = 0; unsigned int i = 0;
result[i] = (num >> FOURTH_BYTE_OFFSET) & 0xFF; result[i] = (number >> FOURTH_BYTE_OFFSET) & 0xFF;
result[++i] = (num >> THIRD_BYTE_OFFSET) & 0xFF; result[++i] = (number >> THIRD_BYTE_OFFSET) & 0xFF;
result[++i] = (num >> SECOND_BYTE_OFFSET) & 0xFF; result[++i] = (number >> SECOND_BYTE_OFFSET) & 0xFF;
result[++i] = num & 0xFF; result[++i] = number & 0xFF;
return result; return result;
} }
@ -1401,20 +1389,20 @@ void ConnectionManager::ResetListener()
int32_t ConnectionManager::GetSessionProtocolType(int sessionId, ProtocolType &protocolType) int32_t ConnectionManager::GetSessionProtocolType(int sessionId, ProtocolType &protocolType)
{ {
std::lock_guard<std::mutex> lock(mutex_); auto listener = GetListener();
if (!listener_) { if (!listener) {
return CAST_ENGINE_ERROR; return CAST_ENGINE_ERROR;
} }
return listener_->GetSessionProtocolType(sessionId, protocolType); return listener->GetSessionProtocolType(sessionId, protocolType);
} }
int32_t ConnectionManager::SetSessionProtocolType(int sessionId, ProtocolType protocolType) int32_t ConnectionManager::SetSessionProtocolType(int sessionId, ProtocolType protocolType)
{ {
std::lock_guard<std::mutex> lock(mutex_); auto listener = GetListener();
if (!listener_) { if (!listener) {
return CAST_ENGINE_ERROR; return CAST_ENGINE_ERROR;
} }
return listener_->SetSessionProtocolType(sessionId, protocolType); return listener->SetSessionProtocolType(sessionId, protocolType);
} }
void ConnectionManager::SendConsultInfo(const std::string &deviceId, int port) void ConnectionManager::SendConsultInfo(const std::string &deviceId, int port)
@ -1622,39 +1610,32 @@ void CastDeviceStateCallback::OnDeviceOnline(const DmDeviceInfo &deviceInfo)
void CastDeviceStateCallback::OnDeviceOffline(const DmDeviceInfo &deviceInfo) void CastDeviceStateCallback::OnDeviceOffline(const DmDeviceInfo &deviceInfo)
{ {
CLOGI("device(%s) is offline", deviceInfo.deviceId); CLOGI("device(%{public}s) is offline", Utils::Mask(deviceInfo.deviceId).c_str());
ConnectionManager::GetInstance().NotifyDeviceIsOffline(deviceInfo.deviceId);
} }
void CastDeviceStateCallback::OnDeviceChanged(const DmDeviceInfo &deviceInfo) void CastDeviceStateCallback::OnDeviceChanged(const DmDeviceInfo &deviceInfo)
{ {
CLOGI("device(%s) is changed", deviceInfo.deviceId); CLOGI("device(%{public}s) is changed", Utils::Mask(deviceInfo.deviceId).c_str());
} }
void CastDeviceStateCallback::OnDeviceReady(const DmDeviceInfo &deviceInfo) void CastDeviceStateCallback::OnDeviceReady(const DmDeviceInfo &deviceInfo)
{ {
CLOGI("device(%s) is ready", deviceInfo.deviceId); CLOGI("device(%{public}s) is ready", Utils::Mask(deviceInfo.deviceId).c_str());
} }
bool ConnectionManager::IsSingle(const CastInnerRemoteDevice &device) bool ConnectionManager::IsSingle(const CastInnerRemoteDevice &device)
{ {
if (device.deviceTypeId == THIRD_TV) { if (device.authVersion == AUTH_VERSION_3) {
return false; CLOGI("Is hw single device");
}
if (device.customData.empty() && device.wifiPort == 0 && device.bleMac.empty()) {
return true; return true;
} }
if (device.customData.empty()) {
return device.wifiPort != 0 || !device.bleMac.empty();
}
return false; return false;
} }
bool ConnectionManager::IsHuaweiDevice(const CastInnerRemoteDevice &device) bool ConnectionManager::IsHuaweiDevice(const CastInnerRemoteDevice &device)
{ {
if (!device.customData.empty()) { if (device.authVersion == AUTH_VERSION_2) {
CLOGI("Is hw device");
return true; return true;
} }
return false; return false;
@ -1662,10 +1643,12 @@ bool ConnectionManager::IsHuaweiDevice(const CastInnerRemoteDevice &device)
bool ConnectionManager::IsThirdDevice(const CastInnerRemoteDevice &device) bool ConnectionManager::IsThirdDevice(const CastInnerRemoteDevice &device)
{ {
if (device.deviceTypeId == THIRD_TV) { if (device.authVersion == AUTH_VERSION_1) {
CLOGI("Is third device");
return true; return true;
} }
return device.bleMac.empty() && device.wifiPort == 0;
return false;
} }
bool ConnectionManager::IsBindTarget(std::string deviceId) bool ConnectionManager::IsBindTarget(std::string deviceId)

View File

@ -138,7 +138,66 @@ void DiscoveryManager::Deinit()
StopDiscovery(); StopDiscovery();
ResetListener(); ResetListener();
DeviceManager::GetInstance().UnInitDeviceManager(PKG_NAME); DeviceManager::GetInstance().UnInitDeviceManager(PKG_NAME);
eventRunner_->Stop(); }
void DiscoveryManager::StartDiscovery(int protocols, std::vector<std::string> drmSchemes)
{
HiSysEventWriteWrap(__func__, {
{"BIZ_SCENE", static_cast<int32_t>(BIZSceneType::DEVICE_DISCOVERY)},
{"BIZ_STATE", static_cast<int32_t>(BIZStateType::BIZ_STATE_BEGIN)},
{"BIZ_STAGE", static_cast<int32_t>(BIZSceneStage::START_DISCOVERY)},
{"STAGE_RES", static_cast<int32_t>(StageResType::STAGE_RES_IDLE)},
{"ERROR_CODE", CAST_RADAR_SUCCESS}}, {
{"TO_CALL_PKG", ""},
{"LOCAL_SESS_NAME", ""},
{"PEER_SESS_NAME", ""},
{"PEER_UDID", ""}});
CLOGI("StartDiscovery in");
protocolType_ = protocols;
drmSchemes_ = drmSchemes;
CastLocalDevice localDevice;
ConnectionManager::GetInstance().GetLocalDeviceInfo(localDevice);
std::lock_guard<std::mutex> lock(mutex_);
std::thread([this]() {
Utils::SetThreadName("DiscoveryEventRunner");
if (eventRunner_ != nullptr) {
eventRunner_->Run();
}
}).detach();
scanCount_ = 0;
for (auto it = remoteDeviceMap_.begin(); it != remoteDeviceMap_.end();) {
std::string deviceId = it->first.deviceId;
CastDeviceDataManager::GetInstance().SetDeviceNotFresh(deviceId);
it++;
}
remoteDeviceMap_.clear();
std::string connectDeviceId = ConnectionManager::GetInstance().GetConnectingDeviceId();
if (!connectDeviceId.empty()) {
auto device = CastDeviceDataManager::GetInstance().GetDeviceByDeviceId(connectDeviceId);
if (device != std::nullopt) {
device->deviceName = "";
remoteDeviceMap_[*device] = scanCount_ + 1;
}
}
uid_ = IPCSkeleton::GetCallingUid();
hasStartDiscovery_ = true;
eventHandler_->RemoveEvent(EVENT_START_DISCOVERY);
eventHandler_->SendEvent(EVENT_START_DISCOVERY);
CLOGI("StartDiscovery out");
}
void DiscoveryManager::StopDiscovery()
{
CLOGI("StopDiscovery in");
hasStartDiscovery_ = false;
SetDeviceNotFresh();
eventHandler_->RemoveAllEvents();
if (eventRunner_ != nullptr) {
eventRunner_->Stop();
}
StopDmDiscovery();
} }
void DiscoveryManager::GetAndReportTrustedDevices() void DiscoveryManager::GetAndReportTrustedDevices()
@ -214,40 +273,6 @@ int DiscoveryManager::GetProtocolType() const
return protocolType_; return protocolType_;
} }
void DiscoveryManager::StartDiscovery()
{
HiSysEventWriteWrap(__func__, {
{"BIZ_SCENE", static_cast<int32_t>(BIZSceneType::DEVICE_DISCOVERY)},
{"BIZ_STATE", static_cast<int32_t>(BIZStateType::BIZ_STATE_BEGIN)},
{"BIZ_STAGE", static_cast<int32_t>(BIZSceneStage::START_DISCOVERY)},
{"STAGE_RES", static_cast<int32_t>(StageResType::STAGE_RES_IDLE)},
{"ERROR_CODE", CAST_RADAR_SUCCESS}}, {
{"TO_CALL_PKG", ""},
{"LOCAL_SESS_NAME", ""},
{"PEER_SESS_NAME", ""},
{"PEER_UDID", ""}});
CLOGI("StartDiscovery in");
scanCount_ = 0;
remoteDeviceMap_.clear();
std::lock_guard<std::mutex> lock(mutex_);
uid_ = IPCSkeleton::GetCallingUid();
eventHandler_->SendEvent(EVENT_START_DISCOVERY);
CLOGI("StartDiscovery out");
}
void DiscoveryManager::StopDiscovery()
{
CLOGI("StopDiscovery in");
SetDeviceNotFresh();
eventHandler_->RemoveAllEvents();
if (eventRunner_ != nullptr) {
eventRunner_->Stop();
}
StopDmDiscovery();
}
void DiscoveryManager::SetListener(std::shared_ptr<IDiscoveryManagerListener> listener) void DiscoveryManager::SetListener(std::shared_ptr<IDiscoveryManagerListener> listener)
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
@ -346,7 +371,7 @@ void DiscoveryManager::NotifyDeviceIsFound(const CastInnerRemoteDevice &newDevic
} }
devices.push_back(newDevice); devices.push_back(newDevice);
std::lock_guard<std::mutex> lock(mutex_); isNotifyDevice_ = system::GetBoolParameter(NOTIFY_DEVICE_FOUND, false);
auto listener = GetListener(); auto listener = GetListener();
if (listener == nullptr) { if (listener == nullptr) {
CLOGE("listener is null"); CLOGE("listener is null");

View File

@ -34,31 +34,33 @@ void RtspChannelManager::ChannelListener::OnDataReceived(const uint8_t *buffer,
{ {
CLOGD("==============Received data length %{public}u timeCost %{public}ld================", length, timeCost); CLOGD("==============Received data length %{public}u timeCost %{public}ld================", length, timeCost);
if (!((channelManager_->algorithmId_ > 0) && auto channelManager = channelManager_.lock();
!Utils::IsArrayAllZero(channelManager_->sessionKeys_, SESSION_KEY_LENGTH))) { if (channelManager == nullptr) {
CLOGE("channelManager == nullptr");
return;
}
if (!((channelManager->algorithmId_ > 0) &&
!Utils::IsArrayAllZero(channelManager->sessionKeys_, SESSION_KEY_LENGTH))) {
CLOGD("==============Not Authed Recv Msg ================"); CLOGD("==============Not Authed Recv Msg ================");
CLOGD("Algorithm id %{public}d, length %{public}u.", channelManager_->algorithmId_, length); CLOGD("Algorithm id %{public}d, length %{public}u.", channelManager->algorithmId_, length);
channelManager_->OnData(buffer, length); channelManager->OnData(buffer, length);
} else { } else {
unsigned int realPktlen = length - EncryptDecrypt::AES_IV_LEN; int decryptDataLen = 0;
std::unique_ptr<uint8_t[]> decryContent = std::make_unique<uint8_t[]>(realPktlen); auto decryContent =
PacketData outputData = { decryContent.get(), 0 }; EncryptDecrypt::GetInstance().DecryptData(channelManager->algorithmId_, { channelManager->sessionKeys_,
bool isSucc = channelManager->sessionKeyLength_ }, { buffer, static_cast<int>(length) }, decryptDataLen);
EncryptDecrypt::GetInstance().DecryptData(channelManager_->algorithmId_, channelManager_->sessionKeys_, if (!decryContent) {
channelManager_->sessionKeyLength_, { buffer, static_cast<int>(length) }, outputData); CLOGE("ERROR: decode fail, length[%{public}u]", length);
if (!isSucc) {
CLOGE("ERROR: decode fail or len [%{public}d],expect[%{public}u]", outputData.length, length);
return; return;
} }
CLOGD("==============Authed Recv Msg ================, decryContent length %{public}u", length); CLOGD("==============Authed Recv Msg ================, decryContent length %{public}u", length);
channelManager_->OnData(decryContent.get(), realPktlen); channelManager->OnData(decryContent.get(), decryptDataLen);
} }
} }
RtspChannelManager::RtspChannelManager(RtspListenerInner *listener, ProtocolType protocolType) RtspChannelManager::RtspChannelManager(std::shared_ptr<RtspListenerInner> listener, ProtocolType protocolType)
: listener_(listener), protocolType_(protocolType) : listener_(listener), protocolType_(protocolType)
{ {
channelListener_ = std::make_shared<ChannelListener>(this);
CLOGI("Out, ProtocolType:%{public}d", protocolType_); CLOGI("Out, ProtocolType:%{public}d", protocolType_);
} }
@ -67,12 +69,14 @@ RtspChannelManager::~RtspChannelManager()
CLOGI("In."); CLOGI("In.");
memset_s(sessionKeys_, SESSION_KEY_LENGTH, 0, SESSION_KEY_LENGTH); memset_s(sessionKeys_, SESSION_KEY_LENGTH, 0, SESSION_KEY_LENGTH);
channelListener_ = nullptr; channelListener_ = nullptr;
StopSafty(false);
ThreadJoin();
} }
std::shared_ptr<IChannelListener> RtspChannelManager::GetChannelListener() std::shared_ptr<IChannelListener> RtspChannelManager::GetChannelListener()
{ {
std::lock_guard<std::mutex> lock(mutex_);
if (channelListener_ == nullptr) {
channelListener_ = std::make_shared<ChannelListener>(shared_from_this());
}
return channelListener_; return channelListener_;
} }
@ -84,8 +88,13 @@ void RtspChannelManager::AddChannel(std::shared_ptr<Channel> channel, const Cast
channel_ = channel; channel_ = channel;
bool isSoftbus = channel->GetRequest().linkType == ChannelLinkType::SOFT_BUS; bool isSoftbus = channel->GetRequest().linkType == ChannelLinkType::SOFT_BUS;
CLOGD("LinkType %{public}d listener_ is %{public}d", isSoftbus, listener_ == nullptr); auto listener = listener_.lock();
listener_->OnPeerReady(isSoftbus); if (!listener) {
CLOGE("listener is nullptr");
return;
}
CLOGD("LinkType %{public}d listener_ is %{public}d", isSoftbus, listener == nullptr);
listener->OnPeerReady(isSoftbus);
} }
void RtspChannelManager::RemoveChannel(std::shared_ptr<Channel> channel) void RtspChannelManager::RemoveChannel(std::shared_ptr<Channel> channel)
@ -108,49 +117,63 @@ void RtspChannelManager::StopSession()
if (isSessionActive_) { if (isSessionActive_) {
memset_s(sessionKeys_, SESSION_KEY_LENGTH, 0, SESSION_KEY_LENGTH); memset_s(sessionKeys_, SESSION_KEY_LENGTH, 0, SESSION_KEY_LENGTH);
isSessionActive_ = false; isSessionActive_ = false;
listener_->OnPeerGone(); auto listener = listener_.lock();
if (listener) {
listener->OnPeerGone();
}
} }
RemoveMessage(Message(static_cast<int>(RtspState::MSG_NEG_TIMEOUT)));
} }
void RtspChannelManager::OnConnected(ChannelLinkType channelLinkType) void RtspChannelManager::OnConnected(ChannelLinkType channelLinkType)
{ {
if (listener_ == nullptr) { auto listener = listener_.lock();
CLOGE("listener is null."); if (!listener) {
CLOGE("listener is nullptr");
return; return;
} }
bool isSoftbus = channelLinkType == ChannelLinkType::SOFT_BUS; bool isSoftbus = channelLinkType == ChannelLinkType::SOFT_BUS;
CLOGI("IsSoftbus %{public}d.", isSoftbus); CLOGI("IsSoftbus %{public}d.", isSoftbus);
listener_->OnPeerReady(isSoftbus); listener->OnPeerReady(isSoftbus);
} }
void RtspChannelManager::OnData(const uint8_t *data, unsigned int length) void RtspChannelManager::OnData(const uint8_t *data, unsigned int length)
{ {
std::string str(reinterpret_cast<const char *>(data), length); std::string str(reinterpret_cast<const char *>(data), length);
CLOGD("In, %{public}s %{public}s", (str.find("RTSP/") == 0) ? "Response...\r\n" : "Request...\r\n", str.c_str()); CLOGD("In, %{public}s %{public}s", (str.find("RTSP/") == 0) ? "Response...\r\n" : "Request...\r\n", str.c_str());
if (listener_ == nullptr) { auto listener = listener_.lock();
CLOGE("listener is null."); if (!listener) {
CLOGE("listener is nullptr");
return; return;
} }
RtspParse msg; RtspParse msg;
RtspParse::ParseMsg(str, msg); RtspParse::ParseMsg(str, msg);
if (Utils::StartWith(str, "RTSP/")) { if (Utils::StartWith(str, "RTSP/")) {
listener_->OnResponse(msg); listener->OnResponse(msg);
} else { } else {
listener_->OnRequest(msg); listener->OnRequest(msg);
} }
} }
void RtspChannelManager::OnError(const std::string &errorCode) void RtspChannelManager::OnError(const std::string &errorCode)
{ {
CLOGI("In, %{public}s.", errorCode.c_str()); CLOGI("In, %{public}s.", errorCode.c_str());
listener_->OnPeerGone(); auto listener = listener_.lock();
if (!listener) {
CLOGE("listener is nullptr");
return;
}
listener->OnPeerGone();
} }
void RtspChannelManager::OnClosed(const std::string &errorCode) void RtspChannelManager::OnClosed(const std::string &errorCode)
{ {
CLOGI("OnClosed %{public}s.", errorCode.c_str()); CLOGI("OnClosed %{public}s.", errorCode.c_str());
listener_->OnPeerGone(); auto listener = listener_.lock();
if (!listener) {
CLOGE("listener is nullptr");
return;
}
listener->OnPeerGone();
} }
bool RtspChannelManager::SendData(const std::string &dataFrame) bool RtspChannelManager::SendData(const std::string &dataFrame)
@ -161,31 +184,23 @@ bool RtspChannelManager::SendData(const std::string &dataFrame)
return false; return false;
} }
size_t pktlen = dataFrame.size(); size_t pktlen = dataFrame.size();
std::unique_ptr<uint8_t[]> encryptContent = std::make_unique<uint8_t[]>(pktlen + EncryptDecrypt::AES_IV_LEN);
PacketData outputData = { encryptContent.get(), 0 };
if (channel->GetRequest().linkType == ChannelLinkType::SOFT_BUS || if (channel->GetRequest().linkType == ChannelLinkType::SOFT_BUS ||
Utils::IsArrayAllZero(sessionKeys_, SESSION_KEY_LENGTH) || algorithmId_ <= 0) { Utils::IsArrayAllZero(sessionKeys_, SESSION_KEY_LENGTH) || algorithmId_ <= 0) {
errno_t ret = memcpy_s(encryptContent.get(), pktlen + EncryptDecrypt::AES_IV_LEN, dataFrame.c_str(), pktlen);
if (ret != EOK) {
CLOGE("ERROR: memory copy error:%{public}d", ret);
return false;
}
outputData.length = static_cast<int>(pktlen);
CLOGD("SendData, get data finish."); CLOGD("SendData, get data finish.");
} else { return channel->Send(reinterpret_cast<const uint8_t *>(dataFrame.c_str()), pktlen);
bool ret = EncryptDecrypt::GetInstance().EncryptData(algorithmId_, sessionKeys_, sessionKeyLength_,
{ reinterpret_cast<const uint8_t *>(dataFrame.c_str()), pktlen }, outputData);
if (!ret || (outputData.length != static_cast<int>(pktlen) + static_cast<int>(EncryptDecrypt::AES_IV_LEN))) {
CLOGE("Encrypt data failed, dataLength: %{public}d, pktlen: %{public}zu", outputData.length, pktlen);
return false;
}
CLOGD("SendData, encrypt data finish.");
} }
int encryptedDataLen = 0;
auto encryptedData = EncryptDecrypt::GetInstance().EncryptData(algorithmId_, { sessionKeys_, sessionKeyLength_ },
{ reinterpret_cast<const uint8_t *>(dataFrame.c_str()), pktlen }, encryptedDataLen);
if (!encryptedData) {
CLOGE("Encrypt data failed, pktlen: %{public}zu", pktlen);
return false;
}
CLOGD("SendData, encrypt data finish.");
CLOGD("SendData, outputData.length %{public}d pktlen %{public}zu send buffer %{public}s.", outputData.length, CLOGD("SendData, encryptedDataLen %{public}d pktlen %{public}zu send buffer %{public}s.", encryptedDataLen,
pktlen, encryptContent.get()); pktlen, encryptedData.get());
return channel->Send(encryptContent.get(), outputData.length); return channel->Send(encryptedData.get(), encryptedDataLen);
} }
bool RtspChannelManager::SendRtspData(const std::string &request) bool RtspChannelManager::SendRtspData(const std::string &request)
@ -204,41 +219,11 @@ bool RtspChannelManager::SendRtspData(const std::string &request)
return SendData(request); return SendData(request);
} }
void RtspChannelManager::CfgNegTimeout(bool isClear)
{
CLOGI("In, %{public}d.", isClear);
if (isClear) {
RemoveMessage(Message(static_cast<int>(RtspState::MSG_NEG_TIMEOUT)));
return;
}
SendCastMessageDelayed(static_cast<int>(RtspState::MSG_NEG_TIMEOUT), KEEP_NEG_TIMEOUT_INTERVAL); // 10s
}
void RtspChannelManager::SetNegAlgorithmId(int algorithmId) void RtspChannelManager::SetNegAlgorithmId(int algorithmId)
{ {
algorithmId_ = algorithmId; algorithmId_ = algorithmId;
CLOGI("SetNegAlgorithmId algorithmId %{public}d.", algorithmId); CLOGI("SetNegAlgorithmId algorithmId %{public}d.", algorithmId);
} }
void RtspChannelManager::HandleMessage(const Message &msg)
{
switch (static_cast<RtspState>(msg.what_)) {
case RtspState::MSG_RTSP_START:
case RtspState::MSG_RTSP_DATA:
case RtspState::MSG_RTSP_CLOSE:
case RtspState::MSG_SEND_KA:
case RtspState::MSG_KA_TIMEOUT:
case RtspState::MSG_NEG_TIMEOUT:
CLOGE("NEG timeout.");
if (listener_ != nullptr) {
listener_->OnPeerGone();
}
break;
default:
break;
}
}
} // namespace CastSessionRtsp } // namespace CastSessionRtsp
} // namespace CastEngineService } // namespace CastEngineService
} // namespace CastEngine } // namespace CastEngine

View File

@ -21,7 +21,6 @@
#include <mutex> #include <mutex>
#include "channel.h" #include "channel.h"
#include "handler.h"
#include "message.h" #include "message.h"
#include "rtsp_listener_inner.h" #include "rtsp_listener_inner.h"
#include "cast_engine_common.h" #include "cast_engine_common.h"
@ -30,10 +29,9 @@ namespace OHOS {
namespace CastEngine { namespace CastEngine {
namespace CastEngineService { namespace CastEngineService {
namespace CastSessionRtsp { namespace CastSessionRtsp {
class RtspChannelManager : public Message, public std::enable_shared_from_this<RtspChannelManager> {
class RtspChannelManager : public Handler, public Message {
public: public:
RtspChannelManager(RtspListenerInner *listener, ProtocolType protocolType); RtspChannelManager(std::shared_ptr<RtspListenerInner> listener, ProtocolType protocolType);
~RtspChannelManager(); ~RtspChannelManager();
void OnConnected(ChannelLinkType channelLinkType); void OnConnected(ChannelLinkType channelLinkType);
@ -49,7 +47,6 @@ public:
std::shared_ptr<IChannelListener> GetChannelListener(); std::shared_ptr<IChannelListener> GetChannelListener();
bool SendRtspData(const std::string &request); bool SendRtspData(const std::string &request);
void CfgNegTimeout(bool isClear);
void SetNegAlgorithmId(int algorithmId); void SetNegAlgorithmId(int algorithmId);
private: private:
@ -66,31 +63,29 @@ private:
class ChannelListener : public IChannelListener { class ChannelListener : public IChannelListener {
public: public:
explicit ChannelListener(RtspChannelManager *channelManager) : channelManager_(channelManager) {} explicit ChannelListener(std::shared_ptr<RtspChannelManager> channelManager) : channelManager_(channelManager)
~ChannelListener() {}
{ ~ChannelListener() {}
channelManager_ = nullptr;
}
void OnDataReceived(const uint8_t *buffer, unsigned int length, long timeCost) final; void OnDataReceived(const uint8_t *buffer, unsigned int length, long timeCost) final;
private: private:
RtspChannelManager *channelManager_; std::weak_ptr<RtspChannelManager> channelManager_;
}; };
constexpr static int SESSION_KEY_LENGTH = 16; constexpr static int SESSION_KEY_LENGTH = 16;
bool SendData(const std::string &dataFrame); bool SendData(const std::string &dataFrame);
void HandleMessage(const Message &msg) override;
uint8_t sessionKeys_[SESSION_KEY_LENGTH] = {0}; uint8_t sessionKeys_[SESSION_KEY_LENGTH] = {0};
uint32_t sessionKeyLength_{ 0 }; uint32_t sessionKeyLength_{ 0 };
RtspListenerInner *listener_; std::weak_ptr<RtspListenerInner> listener_;
bool isSessionActive_{ false }; bool isSessionActive_{ false };
std::shared_ptr<Channel> channel_; std::shared_ptr<Channel> channel_;
std::shared_ptr<ChannelListener> channelListener_; std::shared_ptr<ChannelListener> channelListener_;
int algorithmId_{ 0 }; int algorithmId_{ 0 };
ProtocolType protocolType_; ProtocolType protocolType_;
std::mutex mutex_;
}; };
} // namespace CastSessionRtsp } // namespace CastSessionRtsp
} // namespace CastEngineService } // namespace CastEngineService

View File

@ -41,9 +41,6 @@ std::shared_ptr<IRtspController> IRtspController::GetInstance(std::shared_ptr<IR
RtspController::RtspController(std::shared_ptr<IRtspListener> listener, ProtocolType protocolType, EndType endType) RtspController::RtspController(std::shared_ptr<IRtspListener> listener, ProtocolType protocolType, EndType endType)
: protocolType_(protocolType), listener_(listener), endType_(endType) : protocolType_(protocolType), listener_(listener), endType_(endType)
{ {
rtspNetManager_ = std::make_unique<RtspChannelManager>(this, protocolType);
ResponseFuncMapInit();
RequestFuncMapInit();
CLOGI("Out, endType %{public}d", endType); CLOGI("Out, endType %{public}d", endType);
} }
@ -309,19 +306,22 @@ bool RtspController::StopEngine()
return true; return true;
} }
std::string RtspController::ParseCipherItem(const std::string &item) const std::set<std::string> RtspController::ParseCipherItem(const std::string &item) const
{ {
if (item.empty()) { if (item.empty()) {
return ""; return {};
} }
std::set<std::string> supportedCipherLists;
std::vector<std::string> cipherLists; std::vector<std::string> cipherLists;
Utils::SplitString(item, cipherLists, ", "); Utils::SplitString(item, cipherLists, ", ");
for (size_t index = 0; index < cipherLists.size(); index++) { for (size_t index = 0; index < cipherLists.size(); index++) {
if (Utils::ToLower(cipherLists[index]) == EncryptDecrypt::GetInstance().PC_ENCRYPT_ALG) { if (Utils::ToLower(cipherLists[index]) == EncryptDecrypt::CIPHER_AES_CTR_128) {
return EncryptDecrypt::GetInstance().PC_ENCRYPT_ALG; supportedCipherLists.insert(EncryptDecrypt::CIPHER_AES_CTR_128);
} else if (Utils::ToLower(cipherLists[index]) == EncryptDecrypt::CIPHER_AES_GCM_128) {
supportedCipherLists.insert(EncryptDecrypt::CIPHER_AES_GCM_128);
} }
} }
return ""; return supportedCipherLists;
} }
bool RtspController::ProcessAnnounceRequest(RtspParse &request) bool RtspController::ProcessAnnounceRequest(RtspParse &request)
@ -349,18 +349,26 @@ bool RtspController::ProcessAnnounceRequest(RtspParse &request)
int version = instance.GetVersion(); int version = instance.GetVersion();
CLOGD("AuthNeg: Get algStr is %{public}s version %{public}d", encryptStr.c_str(), version); CLOGD("AuthNeg: Get algStr is %{public}s version %{public}d", encryptStr.c_str(), version);
std::string sendStr = ParseCipherItem(encryptStr); std::set<std::string> cipherList = ParseCipherItem(encryptStr);
if (sendStr.empty() && (listener_ != nullptr)) { if (cipherList.empty() && (listener_ != nullptr)) {
listener_->OnError(ERROR_CODE_DEFAULT); CLOGE("cipherList is empty");
return false; return false;
} }
EncryptionParamInfo encryptionParamInfo{}; EncryptionParamInfo encryptionParamInfo{};
encryptionParamInfo.controlChannelAlgId = static_cast<uint32_t>(instance.GetEncryptMatch(sendStr)); encryptionParamInfo.controlChannelAlgId = static_cast<uint32_t>(instance.GetControlEncryptCipher(cipherList));
encryptionParamInfo.dataChannelAlgId = static_cast<uint32_t>(instance.GetEncryptMatch(sendStr)); encryptionParamInfo.dataChannelAlgId = static_cast<uint32_t>(instance.GetMediaEncryptCipher(cipherList));
negotiatedParamInfo_.SetEncryptionParamInfo(encryptionParamInfo); negotiatedParamInfo_.SetEncryptionParamInfo(encryptionParamInfo);
if (endType_ == EndType::CAST_SOURCE) { if (endType_ == EndType::CAST_SOURCE) {
std::string sendStr;
for (auto &cipher : cipherList) {
if (sendStr.empty()) {
sendStr = cipher;
continue;
}
sendStr += ", " + cipher;
}
std::string req = RtspEncap::EncapAnnounce(sendStr, ++currentSeq_, version); std::string req = RtspEncap::EncapAnnounce(sendStr, ++currentSeq_, version);
rtspNetManager_->SendRtspData(req); rtspNetManager_->SendRtspData(req);
waitRsp_ = WaitResponse::WAITING_RSP_ANNOUNCE; waitRsp_ = WaitResponse::WAITING_RSP_ANNOUNCE;

View File

@ -76,7 +76,7 @@ private:
bool SendAction(ActionType type); bool SendAction(ActionType type);
void ProcessSinkDeviceType(const std::string &content); void ProcessSinkDeviceType(const std::string &content);
bool StopEngine(); bool StopEngine();
std::string ParseCipherItem(const std::string &item) const; std::set<std::string> ParseCipherItem(const std::string &item) const;
bool ProcessOptionRequest(RtspParse &request); bool ProcessOptionRequest(RtspParse &request);
bool ProcessSetupRequest(RtspParse &request); bool ProcessSetupRequest(RtspParse &request);
bool ProcessGetParameterRequestM3(RtspParse &request); bool ProcessGetParameterRequestM3(RtspParse &request);

View File

@ -21,6 +21,7 @@
#include <cstdint> #include <cstdint>
#include <string> #include <string>
#include <set>
#include "openssl/hmac.h" #include "openssl/hmac.h"
#include "openssl/err.h" #include "openssl/err.h"
@ -44,14 +45,17 @@ class EncryptDecrypt final {
public: public:
static EncryptDecrypt &GetInstance(); static EncryptDecrypt &GetInstance();
bool EncryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData, PacketData &outputData); std::unique_ptr<uint8_t[]> EncryptData(int algCode, ConstPacketData sessionKey,
bool DecryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData, PacketData &outputData); ConstPacketData inputData, int &outLen);
std::unique_ptr<uint8_t[]> DecryptData(int algCode, ConstPacketData sessionKey,
ConstPacketData inputData, int &outLen);
std::string GetEncryptInfo(); std::string GetEncryptInfo();
int GetEncryptMatch(const std::string &encyptInfo); int GetMediaEncryptCipher(const std::set<std::string> &cipherList);
int GetControlEncryptCipher(const std::set<std::string> &cipherList);
int GetVersion(); int GetVersion();
static const int AES_KEY_LEN_128 = 16; static const int AES_KEY_LEN_128 = 16;
static const unsigned int AES_IV_LEN = 16; static const int AES_IV_LEN = 16;
static const int AES_KEY_LEN = 16; static const int AES_KEY_LEN = 16;
static const int AES_KEY_SIZE = 16; static const int AES_KEY_SIZE = 16;
static const int PC_ENCRYPT_LEN = 64; static const int PC_ENCRYPT_LEN = 64;
@ -61,7 +65,8 @@ public:
static const int CTR_CODE = 1; static const int CTR_CODE = 1;
static const int GCM_CODE = 2; static const int GCM_CODE = 2;
static const std::string PC_ENCRYPT_ALG; static const std::string CIPHER_AES_CTR_128;
static const std::string CIPHER_AES_GCM_128;
private: private:
enum ErrorCode : int { enum ErrorCode : int {
@ -97,7 +102,7 @@ private:
SEC_ERR_SETAAD_FAIL, SEC_ERR_SETAAD_FAIL,
}; };
static const int AES_GCM_MAX_IVLEN = 12; static const int AES_GCM_MIN_IVLEN = 12;
static const int AES_GCM_SIV_TAG_LEN = 16; static const int AES_GCM_SIV_TAG_LEN = 16;
static const int UNSIGNED_CHAR_MIN = 0; static const int UNSIGNED_CHAR_MIN = 0;
static const int UNSIGNED_CHAR_MAX = 255; static const int UNSIGNED_CHAR_MAX = 255;
@ -111,8 +116,8 @@ private:
ConstPacketData sessionIV); ConstPacketData sessionIV);
int AES128GCMCheckEncryPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo); int AES128GCMCheckEncryPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int AES128GCMCheckDecryptPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo); int AES128GCMCheckDecryptPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo, EVP_CIPHER_CTX *ctx); int EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo, EVP_CIPHER_CTX *ctx); int DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int AES128GCMEncry(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo); int AES128GCMEncry(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
int AES128GCMDecrypt(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo); int AES128GCMDecrypt(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo);
}; };

View File

@ -27,7 +27,8 @@ namespace CastEngine {
namespace CastEngineService { namespace CastEngineService {
DEFINE_CAST_ENGINE_LABEL("Cast-EncryptDecrypt"); DEFINE_CAST_ENGINE_LABEL("Cast-EncryptDecrypt");
const std::string EncryptDecrypt::PC_ENCRYPT_ALG = "aes128ctr"; const std::string EncryptDecrypt::CIPHER_AES_CTR_128 = "aes128ctr";
const std::string EncryptDecrypt::CIPHER_AES_GCM_128 = "aes128gcm";
EncryptDecrypt::EncryptDecrypt() {} EncryptDecrypt::EncryptDecrypt() {}
@ -159,42 +160,43 @@ int EncryptDecrypt::AES128Decrypt(ConstPacketData inputData, PacketData &outputD
int EncryptDecrypt::AES128GCMCheckEncryPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo) int EncryptDecrypt::AES128GCMCheckEncryPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{ {
if ((inputData.length < 0) || (encryInfo.aad.length < 0) || (encryInfo.key.length != AES_KEY_LEN_128) || if (encryInfo.key.length != AES_KEY_LEN_128) {
(encryInfo.iv.length < AES_GCM_MAX_IVLEN) || (encryInfo.tag.length < AES_GCM_SIV_TAG_LEN)) {
return SEC_ERR_INVALID_KEY_LEN; return SEC_ERR_INVALID_KEY_LEN;
} }
if ((inputData.data == nullptr) && (inputData.length > 0)) { if (encryInfo.key.data == nullptr) {
return SEC_ERR_INVALID_KEY;
}
if (inputData.data == nullptr || inputData.length <= 0) {
return SEC_ERR_INVALID_PLAIN; return SEC_ERR_INVALID_PLAIN;
} }
if ((encryInfo.aad.data == nullptr) && (encryInfo.aad.length > 0)) { if (encryInfo.iv.length < AES_GCM_MIN_IVLEN) {
return SEC_ERR_INVALID_AAD; return SEC_ERR_INVALID_IV_LEN;
} }
if ((encryInfo.key.data == nullptr) || (encryInfo.iv.data == nullptr) || (encryInfo.tag.data == nullptr)) { if (encryInfo.iv.data == nullptr) {
return SEC_ERR_INVALID_IV; return SEC_ERR_INVALID_IV;
} }
if (outputData.length < inputData.length) { if (outputData.length < inputData.length || outputData.length <= AES_GCM_SIV_TAG_LEN) {
return SEC_ERR_INVALID_DATA_LEN; return SEC_ERR_INVALID_DATA_LEN;
} }
if ((outputData.data == nullptr) && (outputData.length > 0)) { if (outputData.data == nullptr || outputData.length <= 0) {
return SEC_ERR_INVALID_CIPHERTEXT; return SEC_ERR_INVALID_CIPHERTEXT;
} }
return 0; return 0;
} }
int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo, int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
EVP_CIPHER_CTX *ctx)
{ {
int ret = 0; int ret = 0;
int len = 0; int len = 0;
int cipherTextLen = 0; int cipherTextLen = 0;
EVP_CIPHER_CTX *ctx = nullptr;
// Enctypt // Enctypt
do { do {
/* Create and initialise the context */ /* Create and initialise the context */
ctx = EVP_CIPHER_CTX_new(); ctx = EVP_CIPHER_CTX_new();
if (ctx == nullptr) { if (ctx == nullptr) {
ret = SEC_ERR_CREATECIPHER_FAIL; return SEC_ERR_CREATECIPHER_FAIL;
break;
} }
/* Initialise the encryption operation. */ /* Initialise the encryption operation. */
@ -209,17 +211,14 @@ int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &output
break; break;
} }
/* Initialise key and IV */ if (!EVP_CIPHER_CTX_set_key_length(ctx, encryInfo.key.length)) {
if (EVP_EncryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data) != 1) { ret = SEC_ERR_INVALID_KEY_LEN;
ret = SEC_ERR_INVALID_KEY;
break; break;
} }
/* /* Initialise key and IV */
* Provide any AAD data. This can be called zero or more times as required if (EVP_EncryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data) != 1) {
*/ ret = SEC_ERR_INVALID_KEY;
if (EVP_EncryptUpdate(ctx, nullptr, &len, encryInfo.aad.data, encryInfo.aad.length) != 1) {
ret = SEC_ERR_SETAAD_FAIL;
break; break;
} }
@ -246,73 +245,80 @@ int EncryptDecrypt::EnctyptProcess(ConstPacketData inputData, PacketData &output
} }
cipherTextLen += len; cipherTextLen += len;
if (cipherTextLen + AES_GCM_SIV_TAG_LEN < outputData.length) {
ret = SEC_ERR_INVALID_DATA_LEN;
break;
}
/* Get the tag */ /* Get the tag */
if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, encryInfo.tag.length, encryInfo.tag.data) != 1) { if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, AES_GCM_SIV_TAG_LEN, outputData.data + cipherTextLen) != 1) {
ret = SEC_ERR_GCMGETTAG_FAIL; ret = SEC_ERR_GCMGETTAG_FAIL;
break; break;
} }
outputData.length = cipherTextLen;
ret = 0;
} while (0);
outputData.length = cipherTextLen + AES_GCM_SIV_TAG_LEN;
} while (0);
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
return ret; return ret;
} }
int EncryptDecrypt::AES128GCMEncry(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo) int EncryptDecrypt::AES128GCMEncry(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{ {
EVP_CIPHER_CTX *ctx = nullptr;
int ret = AES128GCMCheckEncryPara(inputData, outputData, encryInfo); int ret = AES128GCMCheckEncryPara(inputData, outputData, encryInfo);
if (ret != 0) { if (ret != 0) {
return ret; return ret;
} }
ret = EnctyptProcess(inputData, outputData, encryInfo, ctx); ret = EnctyptProcess(inputData, outputData, encryInfo);
if (ctx != nullptr) {
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
}
return ret; return ret;
} }
int EncryptDecrypt::AES128GCMCheckDecryptPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo) int EncryptDecrypt::AES128GCMCheckDecryptPara(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{ {
if ((inputData.length < 0) || (encryInfo.aad.length < 0) || (encryInfo.key.length != AES_KEY_LEN_128) || if (encryInfo.key.length != AES_KEY_LEN_128) {
(encryInfo.iv.length < AES_GCM_MAX_IVLEN) || (encryInfo.tag.length < AES_GCM_SIV_TAG_LEN)) {
return SEC_ERR_INVALID_KEY_LEN; return SEC_ERR_INVALID_KEY_LEN;
} }
if ((inputData.data == nullptr) && (inputData.length > 0)) { if (encryInfo.key.data == nullptr) {
return SEC_ERR_INVALID_CIPHERTEXT; return SEC_ERR_INVALID_KEY;
} }
if ((encryInfo.aad.data == nullptr) && (encryInfo.aad.length > 0)) { if (encryInfo.iv.length < AES_GCM_MIN_IVLEN) {
return SEC_ERR_INVALID_AAD; return SEC_ERR_INVALID_IV_LEN;
} }
if ((encryInfo.key.data == nullptr) || (encryInfo.iv.data == nullptr) || (encryInfo.tag.data == nullptr)) { if (encryInfo.iv.data == nullptr) {
return SEC_ERR_INVALID_IV; return SEC_ERR_INVALID_IV;
} }
if (encryInfo.tag.data == nullptr) {
return SEC_ERR_NULL_PTR;
}
if (encryInfo.tag.length < AES_GCM_SIV_TAG_LEN) {
return SEC_ERR_INVALID_DATA_LEN;
}
if (inputData.data == nullptr || inputData.length <= 0) {
return SEC_ERR_INVALID_CIPHERTEXT;
}
if (outputData.length < inputData.length) { if (outputData.length < inputData.length) {
return SEC_ERR_INVALID_DATA_LEN; return SEC_ERR_INVALID_DATA_LEN;
} }
if ((outputData.data == nullptr) && (outputData.length > 0)) { if (outputData.data == nullptr || outputData.length <= 0) {
return SEC_ERR_INVALID_PLAIN; return SEC_ERR_INVALID_PLAIN;
} }
return 0; return 0;
} }
int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo, int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
EVP_CIPHER_CTX *ctx)
{ {
int ret = 0; int ret = 0;
int len = 0; int len = 0;
int plainTextLen = 0; int plainTextLen = 0;
EVP_CIPHER_CTX *ctx = nullptr;
do { do {
/* Create and initialise the context */ /* Create and initialise the context */
ctx = EVP_CIPHER_CTX_new(); ctx = EVP_CIPHER_CTX_new();
if (ctx == nullptr) { if (ctx == nullptr) {
ret = SEC_ERR_CREATECIPHER_FAIL; return SEC_ERR_CREATECIPHER_FAIL;
break;
} }
/* Initialise the decryption operation. */ /* Initialise the decryption operation. */
@ -327,18 +333,15 @@ int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &output
break; break;
} }
/* Initialise key and IV */ if (!EVP_CIPHER_CTX_set_key_length(ctx, encryInfo.key.length)) {
if (!EVP_DecryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data)) { CLOGE("key length does not match key algorithm");
ret = SEC_ERR_INVALID_KEY; ret = SEC_ERR_INVALID_KEY_LEN;
break; break;
} }
/* /* Initialise key and IV */
* Provide any AAD data. This can be called zero or more times as if (!EVP_DecryptInit_ex(ctx, nullptr, nullptr, encryInfo.key.data, encryInfo.iv.data)) {
* required ret = SEC_ERR_INVALID_KEY;
*/
if (!EVP_DecryptUpdate(ctx, nullptr, &len, encryInfo.aad.data, encryInfo.aad.length)) {
ret = SEC_ERR_SETAAD_FAIL;
break; break;
} }
@ -368,143 +371,153 @@ int EncryptDecrypt::DecryptProcess(ConstPacketData inputData, PacketData &output
} }
plainTextLen += len; plainTextLen += len;
outputData.length = plainTextLen; outputData.length = plainTextLen;
ret = 0;
} while (0); } while (0);
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
return ret; return ret;
} }
int EncryptDecrypt::AES128GCMDecrypt(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo) int EncryptDecrypt::AES128GCMDecrypt(ConstPacketData inputData, PacketData &outputData, EncryptInfo &encryInfo)
{ {
EVP_CIPHER_CTX *ctx = nullptr;
int ret = AES128GCMCheckDecryptPara(inputData, outputData, encryInfo); int ret = AES128GCMCheckDecryptPara(inputData, outputData, encryInfo);
if (ret != 0) { if (ret != 0) {
return ret; return ret;
} }
ret = DecryptProcess(inputData, outputData, encryInfo, ctx); ret = DecryptProcess(inputData, outputData, encryInfo);
if (ctx != nullptr) {
EVP_CIPHER_CTX_free(ctx);
ctx = nullptr;
}
return ret; return ret;
} }
bool EncryptDecrypt::EncryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData, std::unique_ptr<uint8_t[]> EncryptDecrypt::EncryptData(int algCode, ConstPacketData sessionKey,
PacketData &outputData) ConstPacketData inputData, int &outLen)
{ {
uint8_t sessionIV[AES_KEY_SIZE] = {0}; if (algCode != CTR_CODE && algCode != GCM_CODE) {
if (outputData.data == nullptr) {
CLOGE("outputData is null");
return false;
}
if (algCode != CTR_CODE) {
CLOGE("encrypt not CTR for extension"); CLOGE("encrypt not CTR for extension");
return false; return nullptr;
} }
GetAESIv(sessionIV, AES_KEY_SIZE); uint8_t sessionIV[AES_IV_LEN] = {0};
GetAESIv(sessionIV, AES_IV_LEN);
int encryptDataLen = inputData.length + AES_KEY_SIZE; int encryptDataLen = inputData.length + AES_IV_LEN;
if (algCode == GCM_CODE) {
encryptDataLen += AES_GCM_SIV_TAG_LEN;
}
std::unique_ptr<uint8_t[]> encryptData = std::make_unique<uint8_t[]>(encryptDataLen); std::unique_ptr<uint8_t[]> encryptData = std::make_unique<uint8_t[]>(encryptDataLen);
if (encryptData == nullptr) { if (encryptData == nullptr) {
return false; return nullptr;
} }
errno_t ret = memset_s(encryptData.get(), encryptDataLen, 0, encryptDataLen); errno_t ret = memset_s(encryptData.get(), encryptDataLen, 0, encryptDataLen);
if (ret != 0) { if (ret != 0) {
return false; return nullptr;
} }
PacketData output = { encryptData.get(), encryptDataLen }; PacketData output = { encryptData.get() + AES_IV_LEN, encryptDataLen - AES_IV_LEN };
ConstPacketData sessionKey = { key, keyLen };
ConstPacketData iv = { sessionIV, AES_IV_LEN }; ConstPacketData iv = { sessionIV, AES_IV_LEN };
ret = AES128Encry(inputData, output, sessionKey, iv); EncryptInfo encryInfo;
if (ret != 0 || output.length > inputData.length) { if (algCode == GCM_CODE) {
encryInfo.key = sessionKey;
encryInfo.iv = iv;
ret = AES128GCMEncry(inputData, output, encryInfo);
} else {
ret = AES128Encry(inputData, output, sessionKey, iv);
}
if (ret != 0) {
CLOGE("encrypt error enLen [%u][%u]", ret, output.length); CLOGE("encrypt error enLen [%u][%u]", ret, output.length);
return false; return nullptr;
} }
ret = memcpy_s(outputData.data, AES_KEY_SIZE, sessionIV, AES_KEY_SIZE); ret = memcpy_s(encryptData.get(), AES_IV_LEN, sessionIV, AES_IV_LEN);
if (ret != 0) { if (ret != 0) {
return false; return nullptr;
}
ret = memcpy_s(outputData.data + AES_KEY_SIZE, output.length, output.data, output.length);
if (ret != 0) {
return false;
} }
outputData.length = output.length + AES_KEY_SIZE; outLen = output.length + AES_IV_LEN;
return true; return encryptData;
} }
bool EncryptDecrypt::DecryptData(int algCode, const uint8_t *key, int keyLen, ConstPacketData inputData, std::unique_ptr<uint8_t[]> EncryptDecrypt::DecryptData(int algCode, ConstPacketData sessionKey,
PacketData &outputData) ConstPacketData inputData, int &outLen)
{ {
uint8_t sessionIV[AES_KEY_SIZE] = {0}; uint8_t sessionIV[AES_IV_LEN] = {0};
if (algCode != CTR_CODE) { if (algCode != CTR_CODE && algCode != GCM_CODE) {
CLOGE("decrypt not CTR for extension"); CLOGE("decrypt not CTR for extension");
return false; return nullptr;
} }
int minLength = (algCode == GCM_CODE) ? (AES_IV_LEN + AES_GCM_SIV_TAG_LEN) : AES_IV_LEN;
if (inputData.length <= AES_KEY_SIZE || inputData.data == nullptr || outputData.data == nullptr) { if (inputData.length <= minLength || inputData.data == nullptr) {
CLOGE("decrypt para error"); CLOGE("decrypt para error, length:%{public}d", inputData.length);
return false; return nullptr;
} }
int32_t ret = memcpy_s(sessionIV, AES_KEY_SIZE, inputData.data, AES_KEY_SIZE); int32_t ret = memcpy_s(sessionIV, AES_IV_LEN, inputData.data, AES_KEY_SIZE);
if (ret != 0) { if (ret != 0) {
CLOGE("memcpy_s failed"); CLOGE("memcpy_s failed");
delete[] inputData.data; return nullptr;
return false;
} }
int deLen = inputData.length - AES_KEY_SIZE; int deLen = inputData.length - AES_IV_LEN;
deLen = (algCode == GCM_CODE) ? (deLen - AES_GCM_SIV_TAG_LEN) : deLen;
std::unique_ptr<uint8_t[]> decryptData = std::make_unique<uint8_t[]>(deLen); std::unique_ptr<uint8_t[]> decryptData = std::make_unique<uint8_t[]>(deLen);
if (decryptData == nullptr) { if (decryptData == nullptr) {
CLOGE("create decrypt data memory failed"); CLOGE("create decrypt data memory failed");
return false; return nullptr;
} }
ret = memset_s(decryptData.get(), deLen, 0, deLen); ret = memset_s(decryptData.get(), deLen, 0, deLen);
if (ret != 0) { if (ret != 0) {
CLOGE("memset_s failed"); CLOGE("memset_s failed");
return false; return nullptr;
} }
PacketData output = { decryptData.get(), deLen }; PacketData output = { decryptData.get(), deLen };
ConstPacketData sessionKey = { key, keyLen };
ConstPacketData iv = { sessionIV, AES_IV_LEN }; ConstPacketData iv = { sessionIV, AES_IV_LEN };
ConstPacketData input = { inputData.data + AES_KEY_SIZE, deLen }; ConstPacketData input = { inputData.data + AES_KEY_SIZE, deLen };
ret = AES128Decrypt(input, output, sessionKey, iv); if (algCode == GCM_CODE) {
EncryptInfo encryInfo;
encryInfo.key = sessionKey;
encryInfo.iv = iv;
encryInfo.tag = { const_cast<uint8_t *>(input.data + input.length), AES_GCM_SIV_TAG_LEN };
ret = AES128GCMDecrypt(input, output, encryInfo);
} else {
ret = AES128Decrypt(input, output, sessionKey, iv);
}
if (ret != 0 || output.length != deLen) { if (ret != 0 || output.length != deLen) {
CLOGE("decrypt error and ret[%{public}d] Len[%u] delen[%{public}d]", ret, output.length, deLen); CLOGE("decrypt error and ret[%{public}d] Len[%u] delen[%{public}d]", ret, output.length, deLen);
return false; return nullptr;
} }
ret = memcpy_s(outputData.data, output.length, output.data, output.length); outLen = output.length;
if (ret != 0) {
return false;
}
outputData.length = output.length;
return true; return decryptData;
} }
std::string EncryptDecrypt::GetEncryptInfo() std::string EncryptDecrypt::GetEncryptInfo()
{ {
return PC_ENCRYPT_ALG; return CIPHER_AES_CTR_128;
} }
int EncryptDecrypt::GetEncryptMatch(const std::string &encyptInfo) int EncryptDecrypt::GetMediaEncryptCipher(const std::set<std::string> &cipherList)
{ {
if (encyptInfo.size() >= PC_ENCRYPT_LEN) { if (cipherList.count(CIPHER_AES_CTR_128) != 0) {
return INVALID_CODE;
}
if (encyptInfo == PC_ENCRYPT_ALG) {
return CTR_CODE; return CTR_CODE;
} }
CLOGE("not support the cipherlist");
return INVALID_CODE; return INVALID_CODE;
} }
int EncryptDecrypt::GetControlEncryptCipher(const std::set<std::string> &cipherList)
{
// GCM is preferred, followed by CTR
if (cipherList.count(CIPHER_AES_GCM_128) != 0) {
return GCM_CODE;
}
if (cipherList.count(CIPHER_AES_CTR_128) != 0) {
return CTR_CODE;
}
CLOGE("not support the cipherlist");
return INVALID_CODE;
}
int EncryptDecrypt::GetVersion() int EncryptDecrypt::GetVersion()
{ {
return VERSION; return VERSION;