fix cache model size

Signed-off-by: w30052974 <wangyifan94@huawei.com>
w30052974 2024-10-18 10:34:20 +08:00
parent 12e714cee2
commit b734414d46
8 changed files with 63 additions and 5 deletions


@@ -56,6 +56,7 @@ struct ExtensionConfig {
std::vector<std::vector<int32_t>> inputDims;
std::vector<std::vector<int32_t>> dynamicDims;
bool isNpuFmShared = false;
bool isExceedRamLimit = false;
};
struct ModelConfig {


@@ -510,7 +510,9 @@ OH_NN_ReturnCode CheckExceedRamLimit(const Compilation* compilation, bool& isExc
} else if (compilation->offlineModelPath != nullptr) {
ret = nnrtService.CheckModelSizeFromPath(compilation->offlineModelPath, isExceedRamLimit);
} else if (compilation->cachePath != nullptr) {
ret = nnrtService.CheckModelSizeFromPath(compilation->cachePath, isExceedRamLimit);
std::string modelName;
compilation->compiler->GetModelName(modelName);
ret = nnrtService.CheckModelSizeFromCache(compilation->cachePath, modelName, isExceedRamLimit);
} else if ((compilation->offlineModelBuffer.first != nullptr) && \
(compilation->offlineModelBuffer.second != size_t(0))) {
ret = nnrtService.CheckModelSizeFromBuffer(
@@ -532,7 +534,7 @@ OH_NN_ReturnCode CheckExceedRamLimit(const Compilation* compilation, bool& isExc
return OH_NN_SUCCESS;
}
OH_NN_ReturnCode AuthenticateModel(const Compilation* compilation)
OH_NN_ReturnCode AuthenticateModel(const Compilation* compilation, bool &exceedRamLimit)
{
bool isExceedRamLimit = false;
OH_NN_ReturnCode retCode = CheckExceedRamLimit(compilation, isExceedRamLimit);
@@ -540,6 +542,7 @@ OH_NN_ReturnCode AuthenticateModel(const Compilation* compilation)
LOGE("AuthenticateModel failed, fail to check if model exceed ram limit.");
return retCode;
}
exceedRamLimit = isExceedRamLimit;
if (!isExceedRamLimit) {
LOGI("Model accupy memory less then limit, no need authenticating.");
@@ -582,7 +585,7 @@ OH_NN_ReturnCode AuthenticateModel(const Compilation* compilation)
return OH_NN_SUCCESS;
}
OH_NN_ReturnCode Authentication(Compilation** compilation)
OH_NN_ReturnCode Authentication(Compilation** compilation, bool &exceedRamLimit)
{
if (compilation == nullptr) {
LOGE("Authentication failed, compilation is nullptr.");
@@ -601,11 +604,14 @@ OH_NN_ReturnCode Authentication(Compilation** compilation)
return OH_NN_SUCCESS;
}
OH_NN_ReturnCode ret = AuthenticateModel(compilationImpl);
bool isExceedRamLimit = false;
OH_NN_ReturnCode ret = AuthenticateModel(compilationImpl, isExceedRamLimit);
if (ret != OH_NN_SUCCESS) {
LOGE("Authentication failed, fail to authenticate model.");
return ret;
}
// Output parameter: passes out whether the model size exceeds the limit.
exceedRamLimit = isExceedRamLimit;
LOGI("Authentication success.");
return OH_NN_SUCCESS;
@@ -732,12 +738,26 @@ NNRT_API OH_NN_ReturnCode OH_NNCompilation_Build(OH_NNCompilation *compilation)
return ret;
}
ret = Authentication(&compilationImpl);
bool isExceedRamLimit = false;
ret = Authentication(&compilationImpl, isExceedRamLimit);
if (ret != OH_NN_SUCCESS) {
LOGE("OH_NNCompilation_Build failed, fail to create compiler.");
return ret;
}
std::unordered_map<std::string, std::vector<char>> configs;
LOGI("[OH_NNCompilation_Build] isExceedRamLimit: %{public}d", static_cast<int>(isExceedRamLimit));
std::vector<char> configContents;
if (isExceedRamLimit) {
configContents.push_back('1');
} else {
configContents.push_back('0');
}
configs["isExceedRamLimit"] = configContents;
compilationImpl->compiler->SetExtensionConfig(configs);
bool isBuild = compilationImpl->compiler->IsBuild();
if (isBuild) {
LOGE("OH_NNCompilation_Build failed, compilation has been built, don't build again.");

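For reference, the corrected hand-off above condenses to the following sketch: the flag is serialized as a single character so the extension map keeps one uniform value type (std::vector<char>) for every key. Names follow the diff; error handling is elided.

// Build side: publish isExceedRamLimit to the compiler through the config map.
std::unordered_map<std::string, std::vector<char>> configs;
std::vector<char> configContents;
configContents.push_back(isExceedRamLimit ? '1' : '0');  // bool -> '1'/'0'
configs["isExceedRamLimit"] = configContents;            // write into configs, not configContents
compilationImpl->compiler->SetExtensionConfig(configs);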

@@ -58,6 +58,7 @@ NNRtServiceApi& NNRtServiceApi::GetInstance()
}
LoadFunction(libNNRtService, "CheckModelSizeFromPath", &nnrtService.CheckModelSizeFromPath);
LoadFunction(libNNRtService, "CheckModelSizeFromCache", &nnrtService.CheckModelSizeFromCache);
LoadFunction(libNNRtService, "CheckModelSizeFromBuffer", &nnrtService.CheckModelSizeFromBuffer);
LoadFunction(libNNRtService, "CheckModelSizeFromModel", &nnrtService.CheckModelSizeFromModel);
LoadFunction(libNNRtService, "GetNNRtModelIDFromPath", &nnrtService.GetNNRtModelIDFromPath);

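LoadFunction itself is outside this diff. A plausible minimal implementation, assuming it wraps dlsym over the already-opened libNNRtService handle (hypothetical sketch; the real helper may differ in signature and logging):

#include <dlfcn.h>
// Resolve one exported symbol and store it in the matching function-pointer
// member; the pointer stays nullptr when the symbol is missing.
template <typename FuncPtr>
static void LoadFunction(void* handle, const char* name, FuncPtr* target)
{
    void* symbol = dlsym(handle, name);
    if (symbol == nullptr) {
        LOGE("LoadFunction failed, symbol %{public}s not found.", name);
        return;
    }
    *target = reinterpret_cast<FuncPtr>(symbol);
}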

@@ -27,6 +27,7 @@ public:
bool IsServiceAvaliable() const;
int (*CheckModelSizeFromPath)(const char* path, bool& exceedLimit) = nullptr;
int (*CheckModelSizeFromCache)(const char* path, const std::string& modelName, bool& exceedLimit) = nullptr;
int (*CheckModelSizeFromBuffer)(const void* buffer, size_t size, bool& exceedLimit) = nullptr;
int (*CheckModelSizeFromModel)(void* model, bool& exceedLimit) = nullptr;
size_t (*GetNNRtModelIDFromPath)(const char*) = nullptr;

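Since every service entry point is a nullable function pointer, callers should null-check before invoking. A usage sketch for the new entry; cachePath and modelName are placeholders, and the 0-means-success return convention is an assumption, not confirmed by this diff:

// Query whether the cached model would exceed the RAM limit.
NNRtServiceApi& nnrtService = NNRtServiceApi::GetInstance();
bool isExceedRamLimit = false;
if (nnrtService.CheckModelSizeFromCache == nullptr) {
    LOGE("CheckModelSizeFromCache is not loaded, skipping the size check.");
    return OH_NN_INVALID_PARAMETER;
}
int ret = nnrtService.CheckModelSizeFromCache(cachePath, modelName, isExceedRamLimit);
if (ret != 0) {  // assumption: 0 means success
    LOGE("CheckModelSizeFromCache failed, ret=%{public}d.", ret);
    return OH_NN_FAILED;
}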

@@ -41,6 +41,7 @@ const std::string EXTENSION_KEY_OP_LAYOUT = "opLayout";
const std::string EXTENSION_KEY_INPUT_DIMS = "InputDims";
const std::string EXTENSION_KEY_DYNAMIC_DIMS = "DynamicDims";
const std::string EXTENSION_KEY_FM_SHARED = "NPU_FM_SHARED";
const std::string EXTENSION_KEY_IS_EXCEED_RAMLIMIT = "isExceedRamLimit";
const std::string NULL_HARDWARE_NAME = "default";
const std::string HARDWARE_NAME = "const.ai.nnrt_deivce";


@@ -41,6 +41,7 @@ OH_NN_ReturnCode NNCompiledCache::Save(const std::vector<OHOS::NeuralNetworkRunt
const std::string& cacheDir,
uint32_t version)
{
LOGI("[NNCompiledCache::Save] m_isExceedRamLimit: %{public}d", static_cast<int>(m_isExceedRamLimit));
if (caches.empty()) {
LOGE("[NNCompiledCache] Save failed, caches is empty.");
return OH_NN_INVALID_PARAMETER;
@@ -156,6 +157,11 @@ void NNCompiledCache::SetModelName(const std::string& modelName)
m_modelName = modelName;
}
void NNCompiledCache::SetIsExceedRamLimit(const bool isExceedRamLimit)
{
m_isExceedRamLimit = isExceedRamLimit;
}
OH_NN_ReturnCode NNCompiledCache::GenerateCacheFiles(const std::vector<OHOS::NeuralNetworkRuntime::Buffer>& caches,
const std::string& cacheDir,
uint32_t version) const
@@ -245,6 +251,13 @@ OH_NN_ReturnCode NNCompiledCache::GenerateCacheModel(const std::vector<OHOS::Neu
int currentOpVersion = std::stoi(currentVersion.substr(OPVERSION_SUBSTR_NUM));
*cacheInfoPtr++ = currentOpVersion;
LOGI("[NNCompiledCache::GenerateCacheModel] m_isExceedRamLimit: %{public}d", static_cast<int>(m_isExceedRamLimit));
if (m_isExceedRamLimit) {
*cacheInfoPtr++ = 1;
} else {
*cacheInfoPtr++ = 0;
}
return OH_NN_SUCCESS;
}

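With this change the cache-info header gains one trailing int64 word. The tail of the header as GenerateCacheModel now writes it, condensed to a sketch; the earlier words (file count, version, device id, checksums) sit outside this hunk and are only assumed here:

// Tail of the int64_t cache-info header, written through cacheInfoPtr.
*cacheInfoPtr++ = currentOpVersion;            // existing trailing word
*cacheInfoPtr++ = m_isExceedRamLimit ? 1 : 0;  // new word added by this commit

Persisting the flag as a full int64 word keeps the header a homogeneous int64_t array at the cost of a few padded bytes.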

@@ -34,6 +34,7 @@ struct NNCompiledCacheInfo {
int64_t deviceId{0};
std::vector<unsigned short> modelCheckSum;
int64_t opVersion{0};
int64_t isExceedRamLimit{0};
};
class NNCompiledCache {
@@ -50,6 +51,7 @@ public:
OH_NN_ReturnCode SetBackend(size_t backendID);
void SetModelName(const std::string& modelName);
void SetIsExceedRamLimit(const bool isExceedRamLimit);
OH_NN_ReturnCode WriteCacheInfo(uint32_t cacheSize,
std::unique_ptr<int64_t[]>& cacheInfo,
const std::string& cacheDir) const;
@@ -72,6 +74,7 @@ private:
size_t m_backendID {0};
std::string m_modelName;
std::shared_ptr<Device> m_device {nullptr};
bool m_isExceedRamLimit {false};
};
} // namespace NeuralNetworkRuntime

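The matching read path is not shown in this commit; presumably the cache loader parses the new trailing word into NNCompiledCacheInfo::isExceedRamLimit. A hypothetical sketch assuming the same sequential header layout, with infoPtr as a placeholder cursor:

// Load side (hypothetical): walk the header in the order it was written.
NNCompiledCacheInfo info;
info.opVersion = *infoPtr++;
info.isExceedRamLimit = *infoPtr++;
if ((info.isExceedRamLimit != 0) && (info.isExceedRamLimit != 1)) {
    LOGE("[NNCompiledCache] cache info carries an invalid isExceedRamLimit flag.");
    return OH_NN_INVALID_FILE;
}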

@@ -32,6 +32,7 @@ const int CACHE_OUTPUT_TENSORDESC_OFFSET = 1;
constexpr int32_t NUMBER_CACHE_INFO_MEMBERS = 3;
const std::string EXTENSION_KEY_MODEL_NAME = "ModelName";
const std::string EXTENSION_KEY_FM_SHARED = "NPU_FM_SHARED";
const std::string EXTENSION_KEY_IS_EXCEED_RAMLIMIT = "isExceedRamLimit";
const int OPVERSION_SUBSTR_NUM = 2;
const std::string CURRENT_VERSION = "0x00000000";
const std::string HIAI_VERSION_PATH = "/data/data/hiai/version";
@@ -558,6 +559,8 @@ OH_NN_ReturnCode NNCompiler::SaveToCacheFile() const
tensorBuffers.emplace_back(outputTensorDescBuffer);
compiledCache.SetModelName(m_extensionConfig.modelName);
LOGI("[NNCompiler::SaveToCacheFile] m_extensionConfig.isExceedRamLimit: %{public}d", static_cast<int>(m_extensionConfig.isExceedRamLimit));
compiledCache.SetIsExceedRamLimit(m_extensionConfig.isExceedRamLimit);
ret = compiledCache.Save(caches, m_cachePath, m_cacheVersion);
if (ret != OH_NN_SUCCESS) {
LOGE("[NNCompiler] SaveToCacheFile failed, error happened when saving model cache.");
@@ -725,6 +728,21 @@ OH_NN_ReturnCode NNCompiler::SetExtensionConfig(const std::unordered_map<std::st
m_extensionConfig.isNpuFmShared = true;
LOGI("[NNCompiler] SetExtensionConfig NpuFmShared enabled.");
}
if (configs.find(EXTENSION_KEY_IS_EXCEED_RAMLIMIT) != configs.end()) {
std::vector<char> value = configs.at(EXTENSION_KEY_IS_EXCEED_RAMLIMIT);
if (value.empty()) {
LOGE("[NNCompiler] SetExtensionConfig get empty model name from configs");
return OH_NN_INVALID_PARAMETER;
}
if (value[0] == '1') {
m_extensionConfig.isExceedRamLimit = true;
} else {
m_extensionConfig.isExceedRamLimit = false;
}
LOGI("[NNCompiler] SetExtensionConfig isExceedRamLimit enabled.");
}
return OH_NN_SUCCESS;
}
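Taken together, the flag makes a full round trip: OH_NNCompilation_Build publishes it as the character '1' or '0' under the isExceedRamLimit key, NNCompiler::SetExtensionConfig parses it into m_extensionConfig.isExceedRamLimit, SaveToCacheFile forwards it to NNCompiledCache, and GenerateCacheModel persists it as the last header word. Condensed from the hunks above, for the exceeding case:

// Round trip of the flag across the layers touched by this commit.
configs["isExceedRamLimit"] = {'1'};      // OH_NNCompilation_Build: serialize
compiler->SetExtensionConfig(configs);    // NNCompiler: parse into m_extensionConfig
compiledCache.SetIsExceedRamLimit(true);  // NNCompiler::SaveToCacheFile: forward
*cacheInfoPtr++ = 1;                      // NNCompiledCache::GenerateCacheModel: persist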