feat:add encrypt option for raw stream data

Signed-off-by: yueyan <yueyan8@huawei.com>
This commit is contained in:
yueyan 2024-03-11 10:15:38 +08:00
parent 5c0e6fc2ed
commit ce3f0d742e
26 changed files with 1770 additions and 47 deletions

View File

@ -136,7 +136,8 @@
"//foundation/communication/dsoftbus/tests/sdk/transmission:integration_test",
"//foundation/communication/dsoftbus/tests/sdk/frame/common:unittest",
"//foundation/communication/dsoftbus/tests/sdk/frame/standard:unittest",
"//foundation/communication/dsoftbus/tests/sdk/frame:fuzztest"
"//foundation/communication/dsoftbus/tests/sdk/frame:fuzztest",
"//foundation/communication/dsoftbus/tests/sdk/transmission:moduletest"
]
}
}

View File

@ -223,6 +223,9 @@ enum SoftBusErrNo {
SOFTBUS_TRANS_SET_APP_INFO_FAILED,
SOFTBUS_TRANS_NOT_META_SESSION,
SOFTBUS_TRANS_SERVER_INIT_FAILED,
SOFTBUS_TRANS_SESSION_SERVER_NOT_FOUND,
SOFTBUS_TRANS_ENCRYPT_ERR,
SOFTBUS_TRANS_DECRYPT_ERR,
/* errno begin: -((203 << 21) | (3 << 16) | 0xFFFF) */
SOFTBUS_AUTH_ERR_BASE = SOFTBUS_ERRNO(AUTH_SUB_MODULE_CODE),

View File

@ -29,13 +29,14 @@ extern "C" {
* @version 2.0
*/
typedef enum {
DATA_TYPE_MESSAGE = 1, /**< Message */
DATA_TYPE_BYTES, /**< Bytes */
DATA_TYPE_FILE, /**< File */
DATA_TYPE_RAW_STREAM, /**< Raw data stream */
DATA_TYPE_VIDEO_STREAM, /**< Video data stream*/
DATA_TYPE_AUDIO_STREAM, /**< Audio data stream*/
DATA_TYPE_SLICE_STREAM, /**< Video slice stream*/
DATA_TYPE_MESSAGE = 1, /**< Message */
DATA_TYPE_BYTES, /**< Bytes */
DATA_TYPE_FILE, /**< File */
DATA_TYPE_RAW_STREAM, /**< Raw data stream */
DATA_TYPE_VIDEO_STREAM, /**< Video data stream */
DATA_TYPE_AUDIO_STREAM, /**< Audio data stream */
DATA_TYPE_SLICE_STREAM, /**< Video slice stream */
DATA_TYPE_RAW_STREAM_ENCRYPED, /**< Encryped raw stream data */
DATA_TYPE_BUTT,
} TransDataType;

View File

@ -393,6 +393,8 @@
"IsSessionExceedLimit";
"DiscRecoveryPublish";
"DiscRecoverySubscribe";
"ClientRawStreamEncryptOptGet";
"ClientRawStreamEncryptDefOptGet";
extern "C++" {
OHOS::StreamAdaptor*;
Communication::SoftBus*;

View File

@ -35,6 +35,8 @@ typedef struct {
int32_t (*OnQosEvent)(int32_t channelId, int32_t channelType, int32_t eventId,
int32_t tvCount, const QosTv *tvList);
int32_t (*OnIdleTimeoutReset)(int32_t sessionId);
int32_t (*OnRawStreamEncryptDefOptGet)(const char *sessionName, bool *isEncrypt);
int32_t (*OnRawStreamEncryptOptGet)(int32_t channelId, int32_t channelType, bool *isEncrypt);
} IClientSessionCallBack;
IClientSessionCallBack *GetClientSessionCb(void);

View File

@ -65,6 +65,7 @@ typedef struct {
int32_t crc;
LinkType linkType[LINK_TYPE_MAX];
uint32_t dataConfig;
bool isEncyptedRawStream;
} SessionInfo;
typedef struct {
@ -82,6 +83,7 @@ typedef struct {
SessionListenerAdapter listener;
ListNode sessionList;
bool permissionState;
bool isSrvEncryptedRawStream;
} ClientSessionServer;
typedef enum {
@ -161,7 +163,8 @@ int32_t ClientAddSocketServer(SoftBusSecType type, const char *pkgName, const ch
int32_t ClientDeleteSocketSession(int32_t sessionId);
int32_t ClientAddSocketSession(const SessionParam *param, int32_t *sessionId, bool *isEnabled);
int32_t ClientAddSocketSession(const SessionParam *param, bool isEncyptedRawStream, int32_t *sessionId,
bool *isEnabled);
int32_t ClientSetListenerBySessionId(int32_t sessionId, const ISocketListener *listener, bool isServer);
@ -180,6 +183,10 @@ bool IsSessionExceedLimit();
int32_t ClientResetIdleTimeoutById(int32_t sessionId);
int32_t ClientGetSessionNameByChannelId(int32_t channelId, int32_t channelType, char *sessionName, int32_t len);
int32_t ClientRawStreamEncryptDefOptGet(const char *sessionName, bool *isEncrypt);
int32_t ClientRawStreamEncryptOptGet(int32_t channelId, int32_t channelType, bool *isEncrypt);
#ifdef __cplusplus
}
#endif

View File

@ -363,5 +363,7 @@ IClientSessionCallBack *GetClientSessionCb(void)
g_sessionCb.OnGetSessionId = ClientGetSessionIdByChannelId;
g_sessionCb.OnQosEvent = TransOnQosEvent;
g_sessionCb.OnIdleTimeoutReset = ClientResetIdleTimeoutById;
g_sessionCb.OnRawStreamEncryptDefOptGet = ClientRawStreamEncryptDefOptGet;
g_sessionCb.OnRawStreamEncryptOptGet = ClientRawStreamEncryptOptGet;
return &g_sessionCb;
}

View File

@ -314,6 +314,7 @@ static ClientSessionServer *GetNewSessionServer(SoftBusSecType type, const char
goto EXIT_ERR;
}
server->listener.isSocketListener = false;
server->isSrvEncryptedRawStream = false;
ListInit(&server->node);
ListInit(&server->sessionList);
@ -1545,7 +1546,7 @@ static ClientSessionServer *GetNewSocketServer(SoftBusSecType type, const char *
if (strcpy_s(server->sessionName, sizeof(server->sessionName), sessionName) != EOK) {
goto EXIT_ERR;
}
server->isSrvEncryptedRawStream = false;
ListInit(&server->node);
ListInit(&server->sessionList);
return server;
@ -1710,7 +1711,20 @@ static bool IsDistributedDataSession(const char *sessionName)
return true;
}
static SessionInfo *GetSocketExistSession(const SessionParam *param)
static bool IsDifferentDataType(const SessionInfo *sessionInfo, int dataType, bool isEncyptedRawStream)
{
if (sessionInfo->info.flag != dataType) {
return true;
}
if (dataType != RAW_STREAM) {
return false;
}
return sessionInfo->isEncyptedRawStream != isEncyptedRawStream;
}
static SessionInfo *GetSocketExistSession(const SessionParam *param, bool isEncyptedRawStream)
{
ClientSessionServer *serverNode = NULL;
SessionInfo *sessionInfo = NULL;
@ -1721,10 +1735,10 @@ static SessionInfo *GetSocketExistSession(const SessionParam *param)
continue;
}
LIST_FOR_EACH_ENTRY(sessionInfo, &(serverNode->sessionList), SessionInfo, node) {
if ((strcmp(sessionInfo->info.peerSessionName, param->peerSessionName) != 0) ||
if (sessionInfo->isServer || (strcmp(sessionInfo->info.peerSessionName, param->peerSessionName) != 0) ||
(strcmp(sessionInfo->info.peerDeviceId, param->peerDeviceId) != 0) ||
(strcmp(sessionInfo->info.groupId, param->groupId) != 0) ||
(sessionInfo->info.flag != param->attr->dataType)) {
IsDifferentDataType(sessionInfo, param->attr->dataType, isEncyptedRawStream)) {
continue;
}
return sessionInfo;
@ -1783,7 +1797,7 @@ static SessionInfo *CreateNewSocketSession(const SessionParam *param)
return session;
}
int32_t ClientAddSocketSession(const SessionParam *param, int32_t *sessionId, bool *isEnabled)
int32_t ClientAddSocketSession(const SessionParam *param, bool isEncyptedRawStream, int32_t *sessionId, bool *isEnabled)
{
if (param == NULL || param->sessionName == NULL || param->groupId == NULL || param->attr == NULL ||
sessionId == NULL) {
@ -1801,7 +1815,7 @@ int32_t ClientAddSocketSession(const SessionParam *param, int32_t *sessionId, bo
return SOFTBUS_LOCK_ERR;
}
SessionInfo *session = GetSocketExistSession(param);
SessionInfo *session = GetSocketExistSession(param, isEncyptedRawStream);
if (session != NULL) {
*sessionId = session->sessionId;
*isEnabled = session->isEnable;
@ -1815,7 +1829,7 @@ int32_t ClientAddSocketSession(const SessionParam *param, int32_t *sessionId, bo
TRANS_LOGE(TRANS_SDK, "create session failed");
return SOFTBUS_TRANS_SESSION_CREATE_FAILED;
}
session->isEncyptedRawStream = isEncyptedRawStream;
int32_t ret = AddSession(param->sessionName, session);
if (ret != SOFTBUS_OK) {
SoftBusFree(session);
@ -1998,6 +2012,9 @@ int32_t ClientSetSocketState(int32_t socket, uint32_t maxIdleTimeout, SessionRol
if (sessionNode->role == SESSION_ROLE_CLIENT) {
sessionNode->maxIdleTime = maxIdleTimeout;
}
if (sessionNode->role == SESSION_ROLE_SERVER) {
serverNode->isSrvEncryptedRawStream = sessionNode->isEncyptedRawStream;
}
(void)SoftBusMutexUnlock(&(g_clientSessionServerList->lock));
return SOFTBUS_OK;
}
@ -2273,3 +2290,73 @@ int32_t ClientGetSessionNameByChannelId(int32_t channelId, int32_t channelType,
TRANS_LOGE(TRANS_SDK, "not found session with channelId=%{public}d", channelId);
return SOFTBUS_ERR;
}
int32_t ClientRawStreamEncryptDefOptGet(const char *sessionName, bool *isEncrypt)
{
if (sessionName == NULL || isEncrypt == NULL) {
TRANS_LOGE(TRANS_SDK, "Invalid param");
return SOFTBUS_INVALID_PARAM;
}
if (g_clientSessionServerList == NULL) {
TRANS_LOGE(TRANS_SDK, "not init");
return SOFTBUS_TRANS_SESSION_SERVER_NOINIT;
}
if (SoftBusMutexLock(&(g_clientSessionServerList->lock)) != SOFTBUS_OK) {
TRANS_LOGE(TRANS_SDK, "lock failed");
return SOFTBUS_LOCK_ERR;
}
ClientSessionServer *serverNode = NULL;
LIST_FOR_EACH_ENTRY(serverNode, &(g_clientSessionServerList->list), ClientSessionServer, node) {
if (strcmp(serverNode->sessionName, sessionName) == 0) {
*isEncrypt = serverNode->isSrvEncryptedRawStream;
(void)SoftBusMutexUnlock(&g_clientSessionServerList->lock);
return SOFTBUS_OK;
}
}
(void)SoftBusMutexUnlock(&g_clientSessionServerList->lock);
char *tmpName = NULL;
Anonymize(sessionName, &tmpName);
TRANS_LOGE(TRANS_SDK, "not found ClientSessionServer by sessionName=%{public}s", tmpName);
AnonymizeFree(tmpName);
return SOFTBUS_TRANS_SESSION_SERVER_NOT_FOUND;
}
int32_t ClientRawStreamEncryptOptGet(int32_t channelId, int32_t channelType, bool *isEncrypt)
{
if (channelId < 0 || isEncrypt == NULL) {
TRANS_LOGE(TRANS_SDK, "Invalid param");
return SOFTBUS_INVALID_PARAM;
}
if (g_clientSessionServerList == NULL) {
TRANS_LOGE(TRANS_SDK, "not init");
return SOFTBUS_TRANS_SESSION_SERVER_NOINIT;
}
if (SoftBusMutexLock(&(g_clientSessionServerList->lock)) != SOFTBUS_OK) {
TRANS_LOGE(TRANS_SDK, "lock failed");
return SOFTBUS_LOCK_ERR;
}
ClientSessionServer *serverNode = NULL;
SessionInfo *sessionNode = NULL;
SessionInfo *nextSessionNode = NULL;
LIST_FOR_EACH_ENTRY(serverNode, &(g_clientSessionServerList->list), ClientSessionServer, node) {
if (IsListEmpty(&serverNode->sessionList)) {
continue;
}
LIST_FOR_EACH_ENTRY_SAFE(sessionNode, nextSessionNode, &(serverNode->sessionList), SessionInfo, node) {
if (sessionNode->channelId == channelId && sessionNode->channelType == (ChannelType)channelType) {
*isEncrypt = sessionNode->isEncyptedRawStream;
(void)SoftBusMutexUnlock(&g_clientSessionServerList->lock);
return SOFTBUS_OK;
}
}
}
(void)SoftBusMutexUnlock(&g_clientSessionServerList->lock);
TRANS_LOGE(TRANS_SDK, "not found session by channelId=%{public}d", channelId);
return SOFTBUS_TRANS_SESSION_INFO_NOT_FOUND;
}

View File

@ -872,7 +872,7 @@ int CreateSocket(const char *pkgName, const char *sessionName)
return ret;
}
static SessionAttribute *CreateSessionAttributeBySocketInfoTrans(const SocketInfo *info)
static SessionAttribute *CreateSessionAttributeBySocketInfoTrans(const SocketInfo *info, bool *isEncyptedRawStream)
{
SessionAttribute *tmpAttr = (SessionAttribute *)SoftBusCalloc(sizeof(SessionAttribute));
if (tmpAttr == NULL) {
@ -880,6 +880,7 @@ static SessionAttribute *CreateSessionAttributeBySocketInfoTrans(const SocketInf
return NULL;
}
*isEncyptedRawStream = false;
tmpAttr->fastTransData = NULL;
tmpAttr->fastTransDataSize = 0;
switch (info->dataType) {
@ -893,8 +894,10 @@ static SessionAttribute *CreateSessionAttributeBySocketInfoTrans(const SocketInf
tmpAttr->dataType = TYPE_FILE;
break;
case DATA_TYPE_RAW_STREAM:
case DATA_TYPE_RAW_STREAM_ENCRYPED:
tmpAttr->dataType = TYPE_STREAM;
tmpAttr->attr.streamAttr.streamType = RAW_STREAM;
*isEncyptedRawStream = (info->dataType == DATA_TYPE_RAW_STREAM_ENCRYPED);
break;
case DATA_TYPE_VIDEO_STREAM:
tmpAttr->dataType = TYPE_STREAM;
@ -922,10 +925,11 @@ int32_t ClientAddSocket(const SocketInfo *info, int32_t *sessionId)
return SOFTBUS_INVALID_PARAM;
}
SessionAttribute *tmpAttr = CreateSessionAttributeBySocketInfoTrans(info);
bool isEncyptedRawStream = false;
SessionAttribute *tmpAttr = CreateSessionAttributeBySocketInfoTrans(info, &isEncyptedRawStream);
if (tmpAttr == NULL) {
TRANS_LOGE(TRANS_SDK, "Create SessionAttribute failed");
return SOFTBUS_ERR;
return SOFTBUS_MALLOC_ERR;
}
SessionParam param = {
@ -937,7 +941,7 @@ int32_t ClientAddSocket(const SocketInfo *info, int32_t *sessionId)
};
bool isEnabled = false;
int32_t ret = ClientAddSocketSession(&param, sessionId, &isEnabled);
int32_t ret = ClientAddSocketSession(&param, isEncyptedRawStream, sessionId, &isEnabled);
if (ret != SOFTBUS_OK) {
SoftBusFree(tmpAttr);
if (ret == SOFTBUS_TRANS_SESSION_REPEATED) {

View File

@ -35,6 +35,8 @@ typedef struct {
void (*OnUdpChannelClosed)(int32_t channelId, ShutdownReason reason);
void (*OnQosEvent)(int channelId, int eventId, int tvCount, const QosTv *tvList);
int32_t (*OnIdleTimeoutReset)(int32_t sessionId);
int32_t (*OnRawStreamEncryptDefOptGet)(const char *sessionName, bool *isEncrypt);
int32_t (*OnRawStreamEncryptOptGet)(int32_t channelId, bool *isEncrypt);
} UdpChannelMgrCb;
typedef struct {

View File

@ -436,6 +436,41 @@ static int32_t OnIdleTimeoutReset(int32_t sessionId)
return g_sessionCb->OnIdleTimeoutReset(sessionId);
}
static int32_t OnRawStreamEncryptOptGet(int32_t channelId, bool *isEncrypt)
{
if (channelId < 0 || isEncrypt == NULL) {
TRANS_LOGE(TRANS_SDK, "invalid param");
return SOFTBUS_INVALID_PARAM;
}
if (g_sessionCb == NULL) {
TRANS_LOGE(TRANS_SDK, "session callback is null");
return SOFTBUS_ERR;
}
if (g_sessionCb->OnRawStreamEncryptOptGet == NULL) {
TRANS_LOGE(TRANS_SDK, "OnRawStreamEncryptOptGet of session callback is null");
return SOFTBUS_ERR;
}
UdpChannel channel;
if (memset_s(&channel, sizeof(UdpChannel), 0, sizeof(UdpChannel)) != EOK) {
TRANS_LOGE(TRANS_SDK, "on udp channel opened memset failed.");
return SOFTBUS_MEM_ERR;
}
int ret = TransGetUdpChannel(channelId, &channel);
if (ret != SOFTBUS_OK) {
TRANS_LOGE(TRANS_SDK, "get udp failed. channelId=%{public}d", channelId);
return ret;
}
if (channel.info.isServer) {
return g_sessionCb->OnRawStreamEncryptDefOptGet(channel.info.mySessionName, isEncrypt);
} else {
return g_sessionCb->OnRawStreamEncryptOptGet(channel.channelId, CHANNEL_TYPE_UDP, isEncrypt);
}
}
static UdpChannelMgrCb g_udpChannelCb = {
.OnStreamReceived = OnStreamReceived,
.OnFileGetSessionId = OnFileGetSessionId,
@ -444,6 +479,7 @@ static UdpChannelMgrCb g_udpChannelCb = {
.OnUdpChannelClosed = OnUdpChannelClosed,
.OnQosEvent = OnQosEvent,
.OnIdleTimeoutReset = OnIdleTimeoutReset,
.OnRawStreamEncryptOptGet = OnRawStreamEncryptOptGet,
};
int32_t ClientTransUdpMgrInit(IClientSessionCallBack *callback)

View File

@ -59,6 +59,7 @@ typedef struct {
StreamType type;
uint8_t *sessionKey;
uint32_t keyLen;
bool isRawStreamEncrypt;
} VtpStreamOpenParam;
int32_t StartVtpStreamChannelServer(int32_t channelId, const VtpStreamOpenParam *param,

View File

@ -55,6 +55,7 @@ public:
const IStreamListener *callback);
void ReleaseAdaptor();
bool GetAliveState();
bool IsEncryptedRawStream();
private:
int64_t channelId_ = -1;
@ -66,6 +67,7 @@ private:
std::pair<uint8_t*, uint32_t> sessionKey_ = std::make_pair(nullptr, 0);
const IStreamListener *callback_ = nullptr;
std::atomic<bool> enableState_ = {false};
bool isRawStreamEncrypt_ = {false};
};
} // namespace OHOS

View File

@ -63,23 +63,11 @@ public:
retStreamData.bufLen = buflen;
ConvertStreamFrameInfo(&tmpf, stream->GetStreamFrameInfo());
} else if (streamType == StreamType::RAW_STREAM) {
int32_t plainDataLength = buflen - adaptor_->GetEncryptOverhead();
if (plainDataLength < 0) {
TRANS_LOGE(TRANS_STREAM,
"bufLen < GetEncryptOverhead. bufLen=%{public}d, GetEncryptOverhead=%{public}zd",
buflen, adaptor_->GetEncryptOverhead());
int32_t ret = ConvertRawStreamData(retbuf, buflen, plainData, retStreamData);
if (ret != SOFTBUS_OK) {
TRANS_LOGE(TRANS_STREAM, "failed to convert raw stream data, ret=%{public}d", ret);
return;
}
plainData = std::make_unique<char[]>(plainDataLength);
ssize_t decLen = adaptor_->Decrypt(retbuf, buflen, plainData.get(),
plainDataLength, adaptor_->GetSessionKey());
if (decLen != plainDataLength) {
TRANS_LOGE(TRANS_STREAM,
"Decrypt failed, dataLen=%{public}d, decLen=%{public}zd", plainDataLength, decLen);
return;
}
retStreamData.buf = plainData.get();
retStreamData.bufLen = plainDataLength;
} else {
TRANS_LOGE(TRANS_STREAM, "Do not support, streamType=%{public}d", streamType);
return;
@ -139,6 +127,36 @@ public:
}
private:
int32_t ConvertRawStreamData(char *buf, int32_t bufLen, std::unique_ptr<char[]> &plainData,
StreamData &retStreamData)
{
if (!adaptor_->IsEncryptedRawStream()) {
retStreamData.buf = buf;
retStreamData.bufLen = bufLen;
return SOFTBUS_OK;
}
ssize_t encryptOverhead = adaptor_->GetEncryptOverhead();
int32_t plainDataLength = bufLen - encryptOverhead;
if (plainDataLength < 0) {
TRANS_LOGE(TRANS_STREAM,
"bufLen < GetEncryptOverhead. bufLen=%{public}d, GetEncryptOverhead=%{public}zd",
bufLen, encryptOverhead);
return SOFTBUS_TRANS_DECRYPT_ERR;
}
plainData = std::make_unique<char[]>(plainDataLength);
ssize_t decLen = adaptor_->Decrypt(buf, bufLen, plainData.get(), plainDataLength,
adaptor_->GetSessionKey());
if (decLen != plainDataLength) {
TRANS_LOGE(TRANS_STREAM,
"Decrypt failed, dataLen=%{public}d, decLen=%{public}zd", plainDataLength, decLen);
return SOFTBUS_TRANS_DECRYPT_ERR;
}
retStreamData.buf = plainData.get();
retStreamData.bufLen = plainDataLength;
return SOFTBUS_OK;
}
std::shared_ptr<StreamAdaptor> adaptor_ = nullptr;
};
} // namespace OHOS

View File

@ -46,6 +46,29 @@ static inline void ConvertStreamFrameInfo(const StreamFrameInfo *inFrameInfo,
outFrameInfo->bitrate = 0;
}
static int32_t CreateRawStream(const std::shared_ptr<StreamAdaptor> &adaptor, const char *buf, ssize_t bufLen,
std::unique_ptr<IStream> &stream)
{
bool isEncrypt = adaptor->IsEncryptedRawStream();
if (!isEncrypt) {
TRANS_LOGD(TRANS_STREAM, "isEncrypt=%{public}d, bufLen=%{public}zd", isEncrypt, bufLen);
stream = IStream::MakeRawStream(buf, bufLen, {}, Communication::SoftBus::Scene::SOFTBUS_SCENE);
return SOFTBUS_OK;
}
ssize_t dataLen = bufLen + adaptor->GetEncryptOverhead();
TRANS_LOGD(TRANS_STREAM, "isEncrypt=%{public}d, bufLen=%{public}zd, encryptOverhead=%{public}zd", isEncrypt,
bufLen, adaptor->GetEncryptOverhead());
std::unique_ptr<char[]> data = std::make_unique<char[]>(dataLen);
ssize_t encLen = adaptor->Encrypt(buf, bufLen, data.get(), dataLen, adaptor->GetSessionKey());
if (encLen != dataLen) {
TRANS_LOGE(TRANS_STREAM, "encrypted failed, dataLen=%{public}zd, encLen=%{public}zd", dataLen, encLen);
return SOFTBUS_TRANS_ENCRYPT_ERR;
}
stream = IStream::MakeRawStream(data.get(), dataLen, {}, Communication::SoftBus::Scene::SOFTBUS_SCENE);
return SOFTBUS_OK;
}
int32_t SendVtpStream(int32_t channelId, const StreamData *inData, const StreamData *ext, const StreamFrameInfo *param)
{
if (inData == nullptr || inData->buf == nullptr || param == nullptr) {
@ -65,17 +88,11 @@ int32_t SendVtpStream(int32_t channelId, const StreamData *inData, const StreamD
std::unique_ptr<IStream> stream = nullptr;
if (adaptor->GetStreamType() == RAW_STREAM) {
ssize_t dataLen = inData->bufLen + adaptor->GetEncryptOverhead();
TRANS_LOGD(TRANS_STREAM,
"bufLen=%{public}d, encryptOverhead=%{public}zd", inData->bufLen, adaptor->GetEncryptOverhead());
std::unique_ptr<char[]> data = std::make_unique<char[]>(dataLen);
ssize_t encLen = adaptor->Encrypt(inData->buf, inData->bufLen, data.get(), dataLen, adaptor->GetSessionKey());
if (encLen != dataLen) {
TRANS_LOGE(TRANS_STREAM, "encrypted failed, dataLen=%{public}zd, encLen=%{public}zd", dataLen, encLen);
return SOFTBUS_ERR;
int32_t ret = CreateRawStream(adaptor, inData->buf, inData->bufLen, stream);
if (ret != SOFTBUS_OK) {
TRANS_LOGE(TRANS_STREAM, "failed to create raw stream, ret=%{public}d", ret);
return ret;
}
stream = IStream::MakeRawStream(data.get(), dataLen, {}, Communication::SoftBus::Scene::SOFTBUS_SCENE);
} else if (adaptor->GetStreamType() == COMMON_VIDEO_STREAM || adaptor->GetStreamType() == COMMON_AUDIO_STREAM) {
if (inData->bufLen < 0 || inData->bufLen > Communication::SoftBus::MAX_STREAM_LEN ||
(ext != nullptr && (ext->bufLen < 0 || ext->bufLen > Communication::SoftBus::MAX_STREAM_LEN))) {

View File

@ -95,6 +95,7 @@ void StreamAdaptor::InitAdaptor(int32_t channelId, const VtpStreamOpenParam *par
callback_ = callback;
streamType_ = param->type;
channelId_ = channelId;
isRawStreamEncrypt_ = param->isRawStreamEncrypt;
}
void StreamAdaptor::ReleaseAdaptor()
@ -161,3 +162,9 @@ ssize_t StreamAdaptor::Decrypt(const void *in, ssize_t inLen, void *out, ssize_t
return outLen;
}
bool StreamAdaptor::IsEncryptedRawStream()
{
// This option only applies to raw stream data
return isRawStreamEncrypt_;
}

View File

@ -125,6 +125,19 @@ static IStreamListener g_streamCallcb = {
.OnRippleStats = OnRippleStats,
};
static int32_t GetRawStreamEncryptOptByChannelId(int32_t channelId, bool *isEncryptRawStream)
{
if (g_udpChannelMgrCb == NULL) {
TRANS_LOGE(TRANS_STREAM, "udp channel callback is null.");
return SOFTBUS_ERR;
}
if (g_udpChannelMgrCb->OnRawStreamEncryptOptGet == NULL) {
TRANS_LOGE(TRANS_STREAM, "OnRawStreamEncryptOptGet of udp channel callback is null.");
return SOFTBUS_ERR;
}
return g_udpChannelMgrCb->OnRawStreamEncryptOptGet(channelId, isEncryptRawStream);
}
int32_t TransOnstreamChannelOpened(const ChannelInfo *channel, int32_t *streamPort)
{
TRANS_LOGD(TRANS_STREAM, "enter.");
@ -137,13 +150,18 @@ int32_t TransOnstreamChannelOpened(const ChannelInfo *channel, int32_t *streamPo
TRANS_LOGE(TRANS_STREAM, "stream type invalid. type=%{public}d", channel->streamType);
return SOFTBUS_INVALID_PARAM;
}
bool isEncryptedRawStream = false;
if (GetRawStreamEncryptOptByChannelId(channel->channelId, &isEncryptedRawStream) != SOFTBUS_OK) {
TRANS_LOGE(TRANS_STREAM, "failed to get encryption option by channelId=%{public}d", channel->channelId);
return SOFTBUS_ERR;
}
if (channel->isServer) {
if (IsSessionExceedLimit()) {
*streamPort = 0;
return SOFTBUS_TRANS_SESSION_CNT_EXCEEDS_LIMIT;
}
VtpStreamOpenParam p1 = { "DSOFTBUS_STREAM", channel->myIp,
NULL, -1, streamType, (uint8_t*)channel->sessionKey, channel->keyLen };
NULL, -1, streamType, (uint8_t*)channel->sessionKey, channel->keyLen, isEncryptedRawStream};
int32_t port = StartVtpStreamChannelServer(channel->channelId, &p1, &g_streamCallcb);
if (port <= 0) {
@ -154,7 +172,7 @@ int32_t TransOnstreamChannelOpened(const ChannelInfo *channel, int32_t *streamPo
TRANS_LOGI(TRANS_STREAM, "stream server success, listen port=%{public}d.", port);
} else {
VtpStreamOpenParam p1 = { "DSOFTBUS_STREAM", channel->myIp, channel->peerIp,
channel->peerPort, streamType, (uint8_t *)channel->sessionKey, channel->keyLen };
channel->peerPort, streamType, (uint8_t *)channel->sessionKey, channel->keyLen, isEncryptedRawStream};
int32_t ret = StartVtpStreamChannelClient(channel->channelId, &p1, &g_streamCallcb);
if (ret <= 0) {

View File

@ -40,6 +40,11 @@ group("integration_test") {
deps = [ "integration_test:integration_test" ]
}
group("moduletest") {
testonly = true
deps = [ "moduletest:moduletest" ]
}
group("unittest") {
testonly = true
deps = [

View File

@ -0,0 +1,28 @@
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import("//build/test.gni")
import("../../../../core/common/dfx/dsoftbus_dfx.gni")
import("../../../../dsoftbus.gni")
if (defined(ohos_lite)) {
group("moduletest") {
testonly = true
deps = []
}
} else {
group("moduletest") {
testonly = true
deps = [ "socket/stream_encrypt_test:TransSocketStreamEncryptMt" ]
}
}

View File

@ -0,0 +1,52 @@
# Copyright (c) 2024 Huawei Device Co., Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import("//build/test.gni")
import("../../../../../../dsoftbus.gni")
module_output_path = "dsoftbus/transmission"
ohos_moduletest("TransSocketStreamEncryptMt") {
module_out_path = module_output_path
sources = [
"common.cpp",
"stream_encrypt_client_mt.cpp",
"stream_encrypt_server_mt.cpp",
"tmessenger.cpp",
]
deps = [
"$dsoftbus_root_path/sdk:softbus_client",
"//third_party/googletest:gtest_main",
]
install_enable = false
sanitize = {
cfi = true
cfi_cross_dso = true
debug = true
}
if (is_standard_system) {
external_deps = [
"access_token:libaccesstoken_sdk",
"access_token:libnativetoken",
"access_token:libtoken_setproc",
"c_utils:utils",
"hilog:libhilog",
]
}
part_name = "dsoftbus"
subsystem_name = "communication"
}

View File

@ -0,0 +1,150 @@
/*
* Copyright (c) 2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <thread>
#include "common.h"
#include "nativetoken_kit.h"
#include "securec.h"
#include "softbus_bus_center.h"
#include "token_setproc.h"
namespace OHOS {
static char g_networkId[NETWORK_ID_BUF_LEN] = { 0 };
static void OnDefNodeOnline(NodeBasicInfo *info)
{
if (info == NULL) {
LOGI("Online: info is null...");
return;
}
(void)strncpy_s(g_networkId, NETWORK_ID_BUF_LEN, info->networkId, NETWORK_ID_BUF_LEN);
LOGI("Online {networkId=%s, deviceName=%s, device type=%d}", info->networkId, info->deviceName, info->deviceTypeId);
}
static void OnDefNodeOffline(NodeBasicInfo *info)
{
if (info == NULL) {
LOGI("Offline: info is null...");
return;
}
LOGI(
"Offline {networkId=%s, deviceName=%s, device type=%d}", info->networkId, info->deviceName, info->deviceTypeId);
}
static void OnDefNodeBasicInfoChanged(NodeBasicInfoType type, NodeBasicInfo *info)
{
if (info == NULL) {
LOGI("InfoChanged: info is null, type=%d", type);
return;
}
LOGI("InfoChanged {networkId=%s, deviceName=%s}", info->networkId, info->deviceName);
}
static void onDefNodeStatusChanged(NodeStatusType type, NodeStatus *status)
{
if (status == NULL) {
LOGI("StatusChanged: info is null, type=%d", type);
return;
}
LOGI("InfoChanged {networkId=%s, authStatus=%d", status->basicInfo.networkId, status->authStatus);
}
static INodeStateCb g_defNodeStateCallback = {
.events = EVENT_NODE_STATE_MASK,
.onNodeOnline = OnDefNodeOnline,
.onNodeOffline = OnDefNodeOffline,
.onNodeBasicInfoChanged = OnDefNodeBasicInfoChanged,
.onNodeStatusChanged = onDefNodeStatusChanged,
};
void AddPermission()
{
uint64_t tokenId;
const char *perms[] = {
OHOS_PERMISSION_DISTRIBUTED_SOFTBUS_CENTER,
OHOS_PERMISSION_DISTRIBUTED_DATASYNC,
};
uint32_t permsSize = sizeof(perms) / sizeof(perms[0]);
NativeTokenInfoParams infoTnstance = {
.dcapsNum = 0,
.permsNum = permsSize,
.aclsNum = 0,
.dcaps = NULL,
.perms = perms,
.acls = NULL,
.processName = "dsoftbus_service",
.aplStr = "system_core",
};
tokenId = GetAccessTokenId(&infoTnstance);
SetSelfTokenID(tokenId);
}
static int CheckRemoteDeviceIsNull(bool isSetNetId)
{
int nodeNum = 0;
NodeBasicInfo *nodeInfo = NULL;
int ret = GetAllNodeDeviceInfo(PKG_NAME, &nodeInfo, &nodeNum);
LOGI("[check]get node number=%d, ret=%d", nodeNum, ret);
if (nodeInfo != NULL && nodeNum > 0) {
LOGI("[check]get netid is=%s", nodeInfo->networkId);
if (isSetNetId) {
(void)strncpy_s(g_networkId, NETWORK_ID_BUF_LEN, nodeInfo->networkId, NETWORK_ID_BUF_LEN);
}
FreeNodeInfo(nodeInfo);
return SOFTBUS_OK;
} else {
LOGI("[check]get nodeInfo is null");
return SOFTBUS_ERR;
}
}
int32_t TestInit()
{
AddPermission();
std::this_thread::sleep_for(std::chrono::seconds(1));
int ret = RegNodeDeviceStateCb(PKG_NAME, &g_defNodeStateCallback);
if (ret != SOFTBUS_OK) {
LOGE("call reg node state callback fail, ret=%d", ret);
return ret;
}
ret = CheckRemoteDeviceIsNull(true);
if (ret != SOFTBUS_OK) {
LOGE("get node fail,please check network, ret=%d", ret);
return ret;
}
return SOFTBUS_OK;
}
int32_t TestDeInit()
{
UnregNodeDeviceStateCb(&g_defNodeStateCallback);
return SOFTBUS_OK;
}
char *WaitOnLineAndGetNetWorkId()
{
while (g_networkId[0] == '\0') {
LOGI("wait online...");
std::this_thread::sleep_for(std::chrono::seconds(1));
}
LOGI("JoinLnn, networkId:%s", g_networkId);
return g_networkId;
}
} // namespace OHOS

View File

@ -0,0 +1,52 @@
/*
* Copyright (c) 2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SOCKET_COMMON_H
#define SOCKET_COMMON_H
#include "socket.h"
#include "softbus_common.h"
#include "softbus_error_code.h"
namespace OHOS {
#define LOG(fmt, args...) \
do { \
fprintf(stdout, "" fmt "\n", ##args); \
} while (false)
#define LOGI(fmt, args...) \
do { \
fprintf(stdout, "[INFO][%s:%d]" fmt "\n", __func__, __LINE__, ##args); \
} while (false)
#define LOGE(fmt, args...) \
do { \
fprintf(stdout, "[ERR][%s:%d]" fmt "\n", __func__, __LINE__, ##args); \
} while (false)
inline const char *PKG_NAME = "com.communication.demo";
inline const char *TEST_NOTIFY_NAME = "com.communication.demo.notify.client";
inline const char *TEST_NOTIFY_SRV_NAME = "com.communication.demo.notify.server";
inline const char *TEST_SESSION_NAME = "com.communication.demo.client";
inline const char *TEST_SESSION_NAME_SRV = "com.communication.demo.server";
inline const char *TEST_STREAM_DATA = "EncryptStreamOrUnencryptStreamTest";
int32_t TestInit();
int32_t TestDeInit();
char *WaitOnLineAndGetNetWorkId();
} // namespace OHOS
#endif // SOCKET_COMMON_H

View File

@ -0,0 +1,356 @@
/*
* Copyright (c) 2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cinttypes>
#include <map>
#include <gtest/gtest.h>
#include "common.h"
#include "session.h"
#include "tmessenger.h"
#define WAIT_TIMEOUT 5
using namespace testing::ext;
namespace OHOS {
class StreamEncryptClientMt : public testing::Test {
public:
StreamEncryptClientMt() { }
~StreamEncryptClientMt() { }
static void SetUpTestCase(void);
static void TearDownTestCase(void);
void SetUp() override { }
void TearDown() override { }
};
void StreamEncryptClientMt::SetUpTestCase(void)
{
int32_t ret = TestInit();
ASSERT_EQ(ret, SOFTBUS_OK);
ret = TMessenger::GetInstance().Open(PKG_NAME, TEST_NOTIFY_NAME, TEST_NOTIFY_SRV_NAME, false);
ASSERT_EQ(ret, SOFTBUS_OK);
}
void StreamEncryptClientMt::TearDownTestCase(void)
{
int32_t ret = TestDeInit();
ASSERT_EQ(ret, SOFTBUS_OK);
TMessenger::GetInstance().Close();
}
void OnShutdownClient(int32_t socket, ShutdownReason reason)
{
LOGI(">> OnShutdownClient {socket:%d, reason:%d}", socket, reason);
}
static ISocketListener g_listener = {
.OnBind = NULL,
.OnShutdown = OnShutdownClient,
.OnBytes = NULL,
.OnMessage = NULL,
.OnStream = NULL,
.OnFile = NULL,
.OnQos = NULL,
};
bool IsTestOk(bool isLocalEncrypt, const std::string sendData, const std::shared_ptr<Response> &resp)
{
if (resp == nullptr) {
LOGE("the response is null");
return false;
}
bool isPeerEncrypt = resp->isEncrypt_;
std::string recvData = resp->recvData_;
LOGI("isLocalEncrypt:%d, sendData:%s", isLocalEncrypt, sendData.c_str());
LOGI("isPeerEncrypt:%d, recvData:%s", isPeerEncrypt, recvData.c_str());
if (isLocalEncrypt == isPeerEncrypt) {
return sendData == recvData;
} else {
return sendData != recvData;
}
}
/*
* @tc.name: RawStreamEncryptTest001
* @tc.desc: Unencrypted raw stream data transmission test
* @tc.type: FUNC
* @tc.require:
*/
HWTEST_F(StreamEncryptClientMt, RawStreamEncryptTest001, TestSize.Level1)
{
/**
* @tc.steps: step 1. set dataType is DATA_TYPE_RAW_STREAM and create socket by 'Socket' function.
* @tc.expect: socket greater zero.
*/
SocketInfo info = {
.name = (char *)TEST_SESSION_NAME,
.pkgName = (char *)PKG_NAME,
.peerName = (char *)TEST_SESSION_NAME_SRV,
.peerNetworkId = NULL,
.dataType = DATA_TYPE_RAW_STREAM,
};
info.peerNetworkId = WaitOnLineAndGetNetWorkId();
int32_t socket = Socket(info);
ASSERT_GT(socket, 0);
/**
* @tc.steps: step 2. set Qos data and call 'Bind' function.
* @tc.expect: 'Bind' function return SOFTBUS_OK.
*/
QosTV qosInfo[] = {
{.qos = QOS_TYPE_MIN_BW, .value = 80 },
{ .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
{ .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
};
int32_t ret = Bind(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &g_listener);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 3. call 'SendStream' to send unencrypted raw stream data.
* @tc.expect: 'SendStream' function return SOFTBUS_OK.
*/
std::string src = TEST_STREAM_DATA;
StreamData data = {
.buf = (char *)(src.c_str()),
.bufLen = src.size(),
};
StreamData ext = { 0 };
StreamFrameInfo param = { 0 };
ret = SendStream(socket, &data, &ext, &param);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 4. call 'Wait' function to get test results returned by server side.
* @tc.expect: 'IsTestOk' function return true.
*/
std::shared_ptr<Response> resp = TMessenger::GetInstance().QueryResult(WAIT_TIMEOUT);
bool testResult = IsTestOk(false, TEST_STREAM_DATA, resp);
ASSERT_TRUE(testResult);
Shutdown(socket);
}
/*
* @tc.name: RawStreamEncryptTest002
* @tc.desc: Encrypted raw stream data transmission test
* @tc.type: FUNC
* @tc.require:
*/
HWTEST_F(StreamEncryptClientMt, RawStreamEncryptTest002, TestSize.Level1)
{
/**
* @tc.steps: step 1. set dataType is DATA_TYPE_RAW_STREAM_ENCRYPED and create socket by 'Socket' function.
* @tc.expect: socket greater zero.
*/
SocketInfo info = {
.name = (char *)TEST_SESSION_NAME,
.pkgName = (char *)PKG_NAME,
.peerName = (char *)TEST_SESSION_NAME_SRV,
.peerNetworkId = NULL,
.dataType = DATA_TYPE_RAW_STREAM_ENCRYPED,
};
info.peerNetworkId = WaitOnLineAndGetNetWorkId();
int32_t socket = Socket(info);
ASSERT_GT(socket, 0);
/**
* @tc.steps: step 2. set Qos data and call 'Bind' function.
* @tc.expect: 'Bind' function return SOFTBUS_OK.
*/
QosTV qosInfo[] = {
{.qos = QOS_TYPE_MIN_BW, .value = 80 },
{ .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
{ .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
};
int32_t ret = Bind(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &g_listener);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 3. call 'SendStream' to send encrypted raw stream data.
* @tc.expect: 'SendStream' function return SOFTBUS_OK.
*/
std::string src = TEST_STREAM_DATA;
StreamData data = {
.buf = (char *)(src.c_str()),
.bufLen = src.size(),
};
StreamData ext = { 0 };
StreamFrameInfo param = { 0 };
ret = SendStream(socket, &data, &ext, &param);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 4. call 'Wait' function to get test results returned by server side.
* @tc.expect: 'IsTestOk' function return true.
*/
std::shared_ptr<Response> resp = TMessenger::GetInstance().QueryResult(WAIT_TIMEOUT);
bool testResult = IsTestOk(true, TEST_STREAM_DATA, resp);
ASSERT_TRUE(testResult);
Shutdown(socket);
}
class SessionStateManager {
public:
static SessionStateManager &GetInstance()
{
static SessionStateManager instance;
return instance;
}
void EnableSessionId(int32_t sessionId)
{
if (sessionId <= 0) {
return;
}
std::unique_lock<std::mutex> lock(sessionIdMutex_);
sessionIdMap_.insert({ sessionId, true });
lock.unlock();
sessionIdCond_.notify_one();
}
void UnenableSessionId(int32_t sessionId)
{
if (sessionId <= 0) {
return;
}
std::unique_lock<std::mutex> lock(sessionIdMutex_);
sessionIdMap_.erase(sessionId);
}
bool WaitEnableSession(int32_t sessionId, uint32_t timeout)
{
bool isEnable = false;
std::unique_lock<std::mutex> lock(sessionIdMutex_);
sessionIdCond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
auto it = sessionIdMap_.find(sessionId);
if (it == sessionIdMap_.end()) {
isEnable = false;
} else {
isEnable = it->second;
}
return isEnable;
});
return isEnable;
}
private:
SessionStateManager() = default;
SessionStateManager(const SessionStateManager &other) = delete;
SessionStateManager(const SessionStateManager &&other) = delete;
SessionStateManager &operator=(const SessionStateManager &other) = delete;
SessionStateManager &operator=(const SessionStateManager &&other) = delete;
std::mutex sessionIdMutex_;
std::condition_variable sessionIdCond_;
std::map<int32_t, bool> sessionIdMap_;
};
static int OnSessionOpened(int sessionId, int result)
{
LOGI(">> OnSessionOpenedServer {sessionId:%d, result=%d", sessionId, result);
if (sessionId <= 0 || result != SOFTBUS_OK) {
LOGE(">> OnSessionOpenedServer, session open failed");
return result;
}
SessionStateManager::GetInstance().EnableSessionId(sessionId);
return SOFTBUS_OK;
}
static void OnSessionClosed(int sessionId)
{
LOGI(">> OnSessionClosedServer {sessionId:%d", sessionId);
SessionStateManager::GetInstance().EnableSessionId(sessionId);
}
/*
* @tc.name: RawStreamEncryptTest003
* @tc.desc: Encrypted raw stream data transmission test
* @tc.type: FUNC
* @tc.require:
*/
HWTEST_F(StreamEncryptClientMt, RawStreamEncryptTest003, TestSize.Level1)
{
/**
* @tc.steps: step 1. call 'CreateSessionServer' function to create session server.
* @tc.expect: 'CreateSessionServer' function return SOFTBUS_OK.
*/
ISessionListener sessionListener = {
.OnSessionOpened = OnSessionOpened,
.OnSessionClosed = OnSessionClosed,
};
int32_t ret = CreateSessionServer(PKG_NAME, TEST_SESSION_NAME, &sessionListener);
ASSERT_EQ(ret, SOFTBUS_OK);
SessionAttribute attr = { 0 };
attr.dataType = TYPE_STREAM;
attr.attr.streamAttr.streamType = RAW_STREAM;
attr.linkTypeNum = 4;
attr.linkType[0] = LINK_TYPE_WIFI_WLAN_5G;
attr.linkType[1] = LINK_TYPE_WIFI_WLAN_2G;
attr.linkType[2] = LINK_TYPE_BR;
attr.linkType[3] = LINK_TYPE_BLE;
attr.fastTransData = NULL;
attr.fastTransDataSize = 0;
/**
* @tc.steps: step 2. call 'OpenSession' function to create session.
* @tc.expect: 'OpenSession' function return SOFTBUS_OK.
*/
int32_t sessionId = OpenSession(TEST_SESSION_NAME, TEST_SESSION_NAME_SRV, WaitOnLineAndGetNetWorkId(), "reserved",
&attr);
ASSERT_GT(sessionId, 0) << "failed to OpenSession, ret=" << sessionId;
/**
* @tc.steps: step 3. call 'WaitEnableSession' function to wait for the session to be opened.
* @tc.expect: 'WaitEnableSession' function return true.
*/
bool isEnable = SessionStateManager::GetInstance().WaitEnableSession(sessionId, 10);
ASSERT_TRUE(isEnable) << "failed to enable session, sessionId" << sessionId;
/**
* @tc.steps: step 4. call 'SendStream' function to send unencrypted raw stream data.
* @tc.expect: 'SendStream' function return SOFTBUS_OK.
*/
std::string src = TEST_STREAM_DATA;
StreamData data = {
.buf = (char *)(src.c_str()),
.bufLen = src.size(),
};
StreamData ext = { 0 };
StreamFrameInfo param = { 0 };
ret = SendStream(sessionId, &data, &ext, &param);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 5. call 'Wait' function to get test results returned by server side.
* @tc.expect: 'IsTestOk' function return true.
*/
std::shared_ptr<Response> resp = TMessenger::GetInstance().QueryResult(WAIT_TIMEOUT);
bool testResult = IsTestOk(false, TEST_STREAM_DATA, resp);
ASSERT_TRUE(testResult);
CloseSession(sessionId);
RemoveSessionServer(PKG_NAME, TEST_SESSION_NAME);
}
} // namespace OHOS

View File

@ -0,0 +1,361 @@
/*
* Copyright (c) 2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cinttypes>
#include <chrono>
#include <thread>
#include <gtest/gtest.h>
#include "common.h"
#include "session.h"
#include "tmessenger.h"
#define SERVER_IDLE_WAIT_TIME 5
using namespace testing::ext;
namespace OHOS {
std::mutex g_recvMutex;
static std::string g_recvData;
class StreamEncryptServerMt : public testing::Test {
public:
StreamEncryptServerMt() { }
~StreamEncryptServerMt() { }
static void SetUpTestCase(void);
static void TearDownTestCase(void);
void SetUp() override { }
void TearDown() override { }
};
void StreamEncryptServerMt::SetUpTestCase(void)
{
int32_t ret = TestInit();
ASSERT_EQ(ret, SOFTBUS_OK);
ret = TMessenger::GetInstance().Open(PKG_NAME, TEST_NOTIFY_SRV_NAME, "", true);
ASSERT_EQ(ret, SOFTBUS_OK);
}
void StreamEncryptServerMt::TearDownTestCase(void)
{
int32_t ret = TestDeInit();
ASSERT_EQ(ret, SOFTBUS_OK);
TMessenger::GetInstance().Close();
}
void OnBindServer(int32_t socket, PeerSocketInfo info)
{
LOGI(">> OnBind {socket:%d, name:%s, networkId:%s, pkgName:%s, dataType:%d}", socket, info.name, info.networkId,
info.pkgName, info.dataType);
}
void OnShutdownServer(int32_t socket, ShutdownReason reason)
{
LOGI(">> OnOnShutdown {socket:%d, reason:%d}", socket, reason);
}
static void OnStreamReceived(int sessionId, const char *testCaseName, const StreamData *data)
{
if (sessionId <= 0) {
LOGI(">> OnStreamReceived, invalid sessionId=%d", sessionId);
return;
}
if (testCaseName == nullptr) {
LOGI(">> OnStreamReceived, testCaseName is nullptr, sessionId=%d", sessionId);
return;
}
if (data == nullptr) {
LOGI(">> OnStreamReceived, data is nullptr, sessionId:%d", sessionId);
return;
}
LOGI(">> OnStreamReceived, sessionId:%d", sessionId);
LOGI(">> OnStreamReceived, testCaseName:%s", testCaseName);
LOGI(">> OnStreamReceived, buf:%s", (data->buf != NULL ? data->buf : "null"));
LOGI(">> OnStreamReceived, bufLen:%d", data->bufLen);
std::lock_guard<std::mutex> lock(g_recvMutex);
g_recvData = std::string((char *)data->buf, data->bufLen);
}
static void OnStreamReceivedWithNoDataType(
int32_t socket, const StreamData *data, const StreamData *ext, const StreamFrameInfo *param)
{
OnStreamReceived(socket, "RawStreamEncryptTestServer001", data);
}
/*
* @tc.name: RawStreamEncryptTestServer001
* @tc.desc: Unencrypted raw stream data transmission test
* @tc.type: FUNC
* @tc.require:
*/
HWTEST_F(StreamEncryptServerMt, RawStreamEncryptTestServer001, TestSize.Level1)
{
/**
* @tc.steps: step 1. do not set dataType and create socket by 'Socket' function.
* @tc.expect: socket greater zero.
*/
SocketInfo info = {
.name = (char *)TEST_SESSION_NAME_SRV,
.pkgName = (char *)PKG_NAME,
};
int32_t socket = Socket(info);
ASSERT_GT(socket, 0);
/**
* @tc.steps: step 2. set Qos data and call 'Listen' function.
* @tc.expect: 'Listen' function return SOFTBUS_OK.
*/
QosTV qosInfo[] = {
{.qos = QOS_TYPE_MIN_BW, .value = 80 },
{ .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
{ .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
};
ISocketListener listener = {
.OnBind = OnBindServer,
.OnShutdown = OnShutdownServer,
.OnBytes = NULL,
.OnMessage = NULL,
.OnStream = OnStreamReceivedWithNoDataType,
.OnFile = NULL,
.OnQos = NULL,
};
int32_t ret = Listen(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 3. Register a callback interface for querying.
*/
TMessenger::GetInstance().RegisterOnQuery([] {
std::lock_guard<std::mutex> lock(g_recvMutex);
std::shared_ptr<Response> resp = std::make_shared<Response>(false, g_recvData);
g_recvData.clear();
LOGI("isEcrtypr:%d, recvData:%s", resp->isEncrypt_, resp->recvData_.c_str());
return resp;
});
/**
* @tc.steps: step 4. Waiting for new connections.
*/
while (true) {
LOG("waiting ...");
std::this_thread::sleep_for(std::chrono::seconds(SERVER_IDLE_WAIT_TIME));
}
Shutdown(socket);
}
static void OnStreamReceivedWithUnencryptOpt(
int32_t socket, const StreamData *data, const StreamData *ext, const StreamFrameInfo *param)
{
OnStreamReceived(socket, "RawStreamEncryptTestServer002", data);
}
/*
* @tc.name: RawStreamEncryptTestServer002
* @tc.desc: Unencrypted raw stream data transmission test
* @tc.type: FUNC
* @tc.require:
*/
HWTEST_F(StreamEncryptServerMt, RawStreamEncryptTestServer002, TestSize.Level1)
{
/**
* @tc.steps: step 1. set dataType is DATA_TYPE_RAW_STREAM and create socket by 'Socket' function.
* @tc.expect: socket greater zero.
*/
SocketInfo info = {
.name = (char *)TEST_SESSION_NAME_SRV,
.pkgName = (char *)PKG_NAME,
.dataType = DATA_TYPE_RAW_STREAM,
};
int32_t socket = Socket(info);
ASSERT_GT(socket, 0);
/**
* @tc.steps: step 2. set Qos data and call 'Listen' function.
* @tc.expect: 'Listen' function return SOFTBUS_OK.
*/
QosTV qosInfo[] = {
{.qos = QOS_TYPE_MIN_BW, .value = 80 },
{ .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
{ .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
};
ISocketListener listener = {
.OnBind = OnBindServer,
.OnShutdown = OnShutdownServer,
.OnBytes = NULL,
.OnMessage = NULL,
.OnStream = OnStreamReceivedWithUnencryptOpt,
.OnFile = NULL,
.OnQos = NULL,
};
int32_t ret = Listen(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 3. Register a callback interface for querying.
*/
TMessenger::GetInstance().RegisterOnQuery([] {
std::lock_guard<std::mutex> lock(g_recvMutex);
std::shared_ptr<Response> resp = std::make_shared<Response>(false, g_recvData);
g_recvData.clear();
LOGI("isEcrtypr:%d, recvData:%s", resp->isEncrypt_, resp->recvData_.c_str());
return resp;
});
/**
* @tc.steps: step 4. Waiting for new connections.
*/
while (true) {
LOG("waiting ...");
std::this_thread::sleep_for(std::chrono::seconds(SERVER_IDLE_WAIT_TIME));
}
Shutdown(socket);
}
static void OnStreamReceivedWithEncryptOpt(
int32_t socket, const StreamData *data, const StreamData *ext, const StreamFrameInfo *param)
{
OnStreamReceived(socket, "RawStreamEncryptTestServer003", data);
}
/*
* @tc.name: RawStreamEncryptTestServer003
* @tc.desc: Unencrypted raw stream data transmission test
* @tc.type: FUNC
* @tc.require:
*/
HWTEST_F(StreamEncryptServerMt, RawStreamEncryptTestServer003, TestSize.Level1)
{
/**
* @tc.steps: step 1. set dataType is DATA_TYPE_RAW_STREAM and create socket by 'Socket' function.
* @tc.expect: socket greater zero.
*/
SocketInfo info = {
.name = (char *)TEST_SESSION_NAME_SRV,
.pkgName = (char *)PKG_NAME,
.dataType = DATA_TYPE_RAW_STREAM_ENCRYPED,
};
int32_t socket = Socket(info);
ASSERT_GT(socket, 0);
/**
* @tc.steps: step 2. set Qos data and call 'Listen' function.
* @tc.expect: 'Listen' function return SOFTBUS_OK.
*/
QosTV qosInfo[] = {
{.qos = QOS_TYPE_MIN_BW, .value = 80 },
{ .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
{ .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
};
ISocketListener listener = {
.OnBind = OnBindServer,
.OnShutdown = OnShutdownServer,
.OnBytes = NULL,
.OnMessage = NULL,
.OnStream = OnStreamReceivedWithEncryptOpt,
.OnFile = NULL,
.OnQos = NULL,
};
int32_t ret = Listen(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 3. Register a callback interface for querying.
*/
TMessenger::GetInstance().RegisterOnQuery([] {
std::lock_guard<std::mutex> lock(g_recvMutex);
std::shared_ptr<Response> resp = std::make_shared<Response>(true, g_recvData);
g_recvData.clear();
LOGI("isEcrtypr:%d, recvData:%s", resp->isEncrypt_, resp->recvData_.c_str());
return resp;
});
/**
* @tc.steps: step 4. Waiting for new connections.
*/
while (true) {
LOG("waiting ...");
std::this_thread::sleep_for(std::chrono::seconds(SERVER_IDLE_WAIT_TIME));
}
Shutdown(socket);
}
static int OnSessionOpenedServer(int sessionId, int result)
{
LOGI(">> OnSessionOpenedServer {sessionId:%d, result=%d", sessionId, result);
if (sessionId <= 0 || result != SOFTBUS_OK) {
return result;
}
return SOFTBUS_OK;
}
static void OnSessionClosedServer(int sessionId)
{
LOGI(">> OnSessionClosedServer {sessionId:%d", sessionId);
}
static void OnStreamReceivedWithOldInterface(
int32_t socket, const StreamData *data, const StreamData *ext, const StreamFrameInfo *param)
{
OnStreamReceived(socket, "RawStreamEncryptTestServer004", data);
}
/*
* @tc.name: RawStreamEncryptTestServer004
* @tc.desc: Use old interace as the server side.
* @tc.type: FUNC
* @tc.require:
*/
HWTEST_F(StreamEncryptServerMt, RawStreamEncryptTestServer004, TestSize.Level1)
{
/**
* @tc.steps: step 1. call 'CreateSessionServer' function to start server.
* @tc.expect: return value is SOFTBUS_OK.
*/
ISessionListener sessionListener = {
.OnSessionOpened = OnSessionOpenedServer,
.OnSessionClosed = OnSessionClosedServer,
.OnStreamReceived = OnStreamReceivedWithOldInterface,
};
int32_t ret = CreateSessionServer(PKG_NAME, TEST_SESSION_NAME_SRV, &sessionListener);
ASSERT_EQ(ret, SOFTBUS_OK);
/**
* @tc.steps: step 2. Register a callback interface for querying.
*/
TMessenger::GetInstance().RegisterOnQuery([] {
std::lock_guard<std::mutex> lock(g_recvMutex);
std::shared_ptr<Response> resp = std::make_shared<Response>(false, g_recvData);
g_recvData.clear();
LOGI("isEcrtypr:%d, recvData:%s", resp->isEncrypt_, resp->recvData_.c_str());
return resp;
});
/**
* @tc.steps: step 3. Waiting for new connections.
*/
while (true) {
LOG("waiting ...");
std::this_thread::sleep_for(std::chrono::seconds(SERVER_IDLE_WAIT_TIME));
}
RemoveSessionServer(PKG_NAME, TEST_SESSION_NAME_SRV);
}
} // namespace OHOS

View File

@ -0,0 +1,378 @@
/*
* Copyright (c) 2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <cinttypes>
#include <chrono>
#include <thread>
#include "common.h"
#include "tmessenger.h"
namespace OHOS {
std::string Request::Encode() const
{
return std::to_string(static_cast<int32_t>(cmd_));
}
std::shared_ptr<Request> Request::Decode(const std::string &data)
{
if (data.empty()) {
LOGE("the data is empty");
return nullptr;
}
Cmd cmd = static_cast<Cmd>(std::stoi(data));
if (cmd < Cmd::QUERY_RESULT || cmd > Cmd::QUERY_RESULT) {
LOGE("invalid cmd=%d", static_cast<int32_t>(cmd));
return nullptr;
}
return std::make_shared<Request>(cmd);
}
std::string Response::Encode() const
{
std::string data = std::to_string(isEncrypt_ ? 1 : 0);
return data + SEPARATOR + recvData_;
}
std::shared_ptr<Response> Response::Decode(const std::string &data)
{
if (data.empty()) {
LOGE("the data is empty");
return nullptr;
}
size_t pos = data.find(SEPARATOR);
if (pos == std::string::npos) {
LOGE("can not find separator in the string data");
return nullptr;
}
int32_t isEncryptVal = static_cast<int32_t>(std::stoi(data.substr(0, pos)));
bool isEncrypt = (isEncryptVal == 1);
std::string recvData = data.substr(pos + 1);
return std::make_shared<Response>(isEncrypt, recvData);
}
Message::~Message()
{
if (msgType_ == MsgType::MSG_SEQ && request != nullptr) {
delete request;
}
if (msgType_ == MsgType::MSG_RSP && response != nullptr) {
delete response;
}
}
std::string Message::Encode() const
{
std::string data = std::to_string(static_cast<int32_t>(msgType_));
switch (msgType_) {
case MsgType::MSG_SEQ:
return request == nullptr ? "" : data + SEPARATOR + request->Encode();
case MsgType::MSG_RSP:
return response == nullptr ? "" : data + SEPARATOR + response->Encode();
default:
LOGE("invalid msgType=%d", static_cast<int32_t>(msgType_));
return "";
}
}
std::shared_ptr<Message> Message::Decode(const std::string &data)
{
size_t pos = data.find(SEPARATOR);
if (pos == std::string::npos) {
return nullptr;
}
MsgType msgType = static_cast<MsgType>(std::stoi(data.substr(0, pos)));
switch (msgType) {
case MsgType::MSG_SEQ: {
std::shared_ptr<Request> req = Request::Decode(data.substr(pos + 1));
if (req == nullptr) {
return nullptr;
}
return std::make_shared<Message>(*req);
}
case MsgType::MSG_RSP: {
std::shared_ptr<Response> rsp = Response::Decode(data.substr(pos + 1));
if (rsp == nullptr) {
return nullptr;
}
return std::make_shared<Message>(*rsp);
}
default:
LOGE("invalid msgType=%d", static_cast<int32_t>(msgType));
return nullptr;
}
}
int32_t TMessenger::Open(
const std::string &pkgName, const std::string &myName, const std::string &peerName, bool isServer)
{
isServer_ = isServer;
return isServer_ ? StartListen(pkgName, myName) : StartConnect(pkgName, myName, peerName);
}
void TMessenger::Close()
{
if (socket_ > 0) {
Shutdown(socket_);
socket_ = -1;
}
if (listenSocket_ > 0) {
Shutdown(listenSocket_);
listenSocket_ = -1;
}
pkgName_.clear();
myName_.clear();
peerName_.clear();
peerNetworkId_.clear();
msgList_.clear();
}
int32_t TMessenger::StartListen(const std::string &pkgName, const std::string &myName)
{
if (listenSocket_ > 0) {
return SOFTBUS_OK;
}
SocketInfo info = {
.pkgName = (char *)(pkgName.c_str()),
.name = (char *)(myName.c_str()),
};
int32_t socket = Socket(info);
if (socket <= 0) {
LOGE("failed to create socket, ret=%d", socket);
return SOFTBUS_ERR;
}
LOGI("create listen socket=%d", socket);
QosTV qosInfo[] = {
{.qos = QOS_TYPE_MIN_BW, .value = 80 },
{ .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
{ .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
};
static ISocketListener listener = {
.OnBind = TMessenger::OnBind,
.OnMessage = TMessenger::OnMessage,
.OnShutdown = TMessenger::OnShutdown,
};
int32_t ret = Listen(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
if (ret != SOFTBUS_OK) {
LOGE("failed to listen, socket=%d", socket);
Shutdown(socket);
return ret;
}
listenSocket_ = socket;
pkgName_ = pkgName;
myName_ = myName;
return SOFTBUS_OK;
}
int32_t TMessenger::StartConnect(const std::string &pkgName, const std::string &myName, const std::string &peerName)
{
if (socket_ > 0) {
return SOFTBUS_OK;
}
SocketInfo info = {
.pkgName = (char *)(pkgName.c_str()),
.name = (char *)(myName.c_str()),
.peerName = (char *)(peerName.c_str()),
.peerNetworkId = NULL,
.dataType = DATA_TYPE_MESSAGE,
};
info.peerNetworkId = OHOS::WaitOnLineAndGetNetWorkId();
int32_t socket = Socket(info);
if (socket <= 0) {
LOGE("failed to create socket, ret=%d", socket);
return socket;
}
LOGI("create bind socket=%d", socket);
QosTV qosInfo[] = {
{.qos = QOS_TYPE_MIN_BW, .value = 80 },
{ .qos = QOS_TYPE_MAX_LATENCY, .value = 4000},
{ .qos = QOS_TYPE_MIN_LATENCY, .value = 2000},
};
static ISocketListener listener = {
.OnMessage = OnMessage,
.OnShutdown = OnShutdown,
};
int32_t ret = Bind(socket, qosInfo, sizeof(qosInfo) / sizeof(qosInfo[0]), &listener);
if (ret != SOFTBUS_OK) {
LOGE("failed to bind, socket=%d, ret=%d", socket, ret);
Shutdown(socket);
return ret;
}
pkgName_ = pkgName;
myName_ = myName;
peerNetworkId_ = info.peerNetworkId;
peerName_ = peerName;
socket_ = socket;
return SOFTBUS_OK;
}
void TMessenger::OnBind(int32_t socket, PeerSocketInfo info)
{
TMessenger::GetInstance().SetConnectSocket(socket, info);
}
void TMessenger::OnMessage(int32_t socket, const void *data, uint32_t dataLen)
{
std::string result((char *)data, dataLen);
TMessenger::GetInstance().OnMessageRecv(result);
}
void TMessenger::OnShutdown(int32_t socket, ShutdownReason reason)
{
TMessenger::GetInstance().CloseSocket(socket);
}
void TMessenger::SetConnectSocket(int32_t socket, PeerSocketInfo info)
{
if (socket_ > 0) {
return;
}
socket_ = socket;
peerName_ = info.name;
peerNetworkId_ = info.networkId;
}
void TMessenger::OnMessageRecv(const std::string &result)
{
std::shared_ptr<Message> msg = Message::Decode(result);
if (msg == nullptr) {
LOGE("receive invalid message");
return;
}
switch (msg->msgType_) {
case Message::MsgType::MSG_SEQ: {
OnRequest();
break;
}
case Message::MsgType::MSG_RSP: {
std::unique_lock<std::mutex> lock(recvMutex_);
msgList_.push_back(msg);
lock.unlock();
recvCond_.notify_one();
break;
}
default:
break;
}
}
void TMessenger::OnRequest()
{
std::thread t([&] {
std::shared_ptr<Response> resp = onQuery_();
Message msg { *resp };
int ret = Send(msg);
if (ret != SOFTBUS_OK) {
LOGE("failed to send response");
}
});
t.detach();
}
void TMessenger::CloseSocket(int32_t socket)
{
if (socket_ == socket) {
Shutdown(socket_);
socket_ = -1;
}
}
std::shared_ptr<Response> TMessenger::QueryResult(uint32_t timeout)
{
Request req { Request::Cmd::QUERY_RESULT };
Message msg { req };
int32_t ret = Send(msg);
if (ret != SOFTBUS_OK) {
LOGE("failed to query result, ret=%d", ret);
return nullptr;
}
return WaitResponse(timeout);
}
int32_t TMessenger::Send(const Message &msg)
{
std::string data = msg.Encode();
if (data.empty()) {
LOGE("the data is empty");
return SOFTBUS_ERR;
}
int32_t ret = SendMessage(socket_, data.c_str(), data.size());
if (ret != SOFTBUS_OK) {
LOGE("failed to send message, socket=%d, ret=%d", socket_, ret);
}
return ret;
}
std::shared_ptr<Response> TMessenger::WaitResponse(uint32_t timeout)
{
std::unique_lock<std::mutex> lock(recvMutex_);
std::shared_ptr<Response> rsp = nullptr;
if (recvCond_.wait_for(lock, std::chrono::seconds(timeout), [&] {
rsp = GetMessageFromRecvList(Message::MsgType::MSG_RSP);
return rsp != nullptr;
})) {
return rsp;
}
LOGE("no result received");
return nullptr;
}
std::shared_ptr<Response> TMessenger::GetMessageFromRecvList(Message::MsgType type)
{
auto it = std::find_if(msgList_.begin(), msgList_.end(), [&] (const std::shared_ptr<Message> &it) {
return it->msgType_ == type;
});
if (it == msgList_.end() || *it == nullptr) {
return nullptr;
}
const Response *rsp = (*it)->response;
if (rsp == nullptr) {
msgList_.erase(it);
return nullptr;
}
std::shared_ptr<Response> resp = std::make_shared<Response>(*rsp);
msgList_.erase(it);
return resp;
}
void TMessenger::RegisterOnQuery(TMessenger::OnQueryCallback callback)
{
onQuery_ = callback;
}
} // namespace OHOS

View File

@ -0,0 +1,131 @@
/*
* Copyright (c) 2024 Huawei Device Co., Ltd.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TMESSENGER_H
#define TMESSENGER_H
#include <condition_variable>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include "common.h"
#include "socket.h"
#define SEPARATOR "|"
namespace OHOS {
class Request {
public:
enum class Cmd {
QUERY_RESULT,
};
explicit Request(Request::Cmd cmd) : cmd_(cmd) { }
std::string Encode() const;
static std::shared_ptr<Request> Decode(const std::string &data);
Cmd cmd_;
};
class Response {
public:
Response(bool isEncrypt, const std::string &recvData) : isEncrypt_(isEncrypt), recvData_(recvData) { }
std::string Encode() const;
static std::shared_ptr<Response> Decode(const std::string &data);
bool isEncrypt_;
std::string recvData_;
};
class Message {
public:
enum class MsgType : int32_t {
MSG_SEQ,
MSG_RSP,
};
explicit Message(const Request &req) : msgType_(MsgType::MSG_SEQ), request(new Request(req)) { }
explicit Message(const Response &rsp) : msgType_(MsgType::MSG_RSP), response(new Response(rsp)) { }
~Message();
std::string Encode() const;
static std::shared_ptr<Message> Decode(const std::string &data);
MsgType msgType_;
union {
Request *request;
Response *response;
};
};
// class 'TMessenger' is used to notify test result
class TMessenger {
public:
static TMessenger &GetInstance()
{
static TMessenger instance;
return instance;
}
// Start a client or server
int32_t Open(const std::string &pkgName, const std::string &myName, const std::string &peerName, bool isServer);
void Close();
std::shared_ptr<Response> QueryResult(uint32_t timeout);
using OnQueryCallback = std::function<std::shared_ptr<Response>(void)>;
void RegisterOnQuery(OnQueryCallback callback);
private:
TMessenger() = default;
TMessenger(const TMessenger &other) = delete;
TMessenger(const TMessenger &&other) = delete;
TMessenger &operator=(const TMessenger &other) = delete;
TMessenger &operator=(const TMessenger &&other) = delete;
int32_t StartListen(const std::string &pkgName, const std::string &myName);
int32_t StartConnect(const std::string &pkgName, const std::string &myName, const std::string &peerName);
static void OnBind(int32_t socket, PeerSocketInfo info);
static void OnMessage(int32_t socket, const void *data, uint32_t dataLen);
static void OnShutdown(int32_t socket, ShutdownReason reason);
void SetConnectSocket(int32_t socket, PeerSocketInfo info);
void OnMessageRecv(const std::string &result);
void OnRequest();
void CloseSocket(int32_t socket);
int32_t Send(const Message &msg);
std::shared_ptr<Response> WaitResponse(uint32_t timeout);
std::shared_ptr<Response> GetMessageFromRecvList(Message::MsgType type);
std::string pkgName_ { "" };
std::string myName_ { "" };
std::string peerNetworkId_ { "" };
std::string peerName_ { "" };
bool isServer_ { false }; // Indicates the instance is a client or server.
int32_t listenSocket_ { -1 }; // Used to listen the connection from client side.
int32_t socket_ { -1 }; // Indicates the client socket.
std::mutex recvMutex_;
std::condition_variable recvCond_;
std::list<std::shared_ptr<Message>> msgList_;
OnQueryCallback onQuery_;
};
} // namespace OHOS
#endif // TMESSENGER_H