mirror of
https://gitee.com/openharmony/ai_neural_network_runtime
synced 2025-01-20 22:54:35 +00:00
commit
24d94d5f56
@ -32,6 +32,7 @@ constexpr int32_t NUMBER_CACHE_INFO_MEMBERS = 3;
|
|||||||
constexpr int32_t HEX_UNIT = 16;
|
constexpr int32_t HEX_UNIT = 16;
|
||||||
constexpr char ROOT_DIR_STR = '/';
|
constexpr char ROOT_DIR_STR = '/';
|
||||||
constexpr char DOUBLE_SLASH_STR[] = "//";
|
constexpr char DOUBLE_SLASH_STR[] = "//";
|
||||||
|
constexpr int OPVERSION_SUBSTR_NUM = 2;
|
||||||
|
|
||||||
OH_NN_ReturnCode NNCompiledCache::Save(const std::vector<OHOS::NeuralNetworkRuntime::Buffer>& caches,
|
OH_NN_ReturnCode NNCompiledCache::Save(const std::vector<OHOS::NeuralNetworkRuntime::Buffer>& caches,
|
||||||
const std::string& cacheDir,
|
const std::string& cacheDir,
|
||||||
@ -162,7 +163,7 @@ OH_NN_ReturnCode NNCompiledCache::GenerateCacheFiles(const std::vector<OHOS::Neu
|
|||||||
uint32_t version) const
|
uint32_t version) const
|
||||||
{
|
{
|
||||||
const size_t cacheNumber = caches.size();
|
const size_t cacheNumber = caches.size();
|
||||||
uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber;
|
uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber + 1;
|
||||||
std::unique_ptr<int64_t[]> cacheInfo = CreateUniquePtr<int64_t[]>(cacheSize);
|
std::unique_ptr<int64_t[]> cacheInfo = CreateUniquePtr<int64_t[]>(cacheSize);
|
||||||
if (cacheInfo == nullptr) {
|
if (cacheInfo == nullptr) {
|
||||||
LOGE("[NNCompiledCache] GenerateCacheFiles failed, fail to create cacheInfo instance.");
|
LOGE("[NNCompiledCache] GenerateCacheFiles failed, fail to create cacheInfo instance.");
|
||||||
@ -232,6 +233,17 @@ OH_NN_ReturnCode NNCompiledCache::GenerateCacheModel(const std::vector<OHOS::Neu
|
|||||||
cacheModelStream.close();
|
cacheModelStream.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string currentVersion = "0x00000000";
|
||||||
|
std::string opVersionPath = "/data/data/hiai/version";
|
||||||
|
std::ifstream inf(opVersionPath.c_str());
|
||||||
|
if (inf.is_open()) {
|
||||||
|
getline(inf, currentVersion);
|
||||||
|
}
|
||||||
|
|
||||||
|
int currentOpVersion = std::stoi(currentVersion.substr(OPVERSION_SUBSTR_NUM));
|
||||||
|
*cacheInfoPtr++ = currentOpVersion;
|
||||||
|
inf.close();
|
||||||
|
|
||||||
return OH_NN_SUCCESS;
|
return OH_NN_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,6 +326,11 @@ OH_NN_ReturnCode NNCompiledCache::CheckCacheInfo(NNCompiledCacheInfo& modelCache
|
|||||||
modelCacheInfo.modelCheckSum[i] = static_cast<unsigned short>(modelCheckSum[i]);
|
modelCacheInfo.modelCheckSum[i] = static_cast<unsigned short>(modelCheckSum[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!infoCacheFile.read(reinterpret_cast<char*>(&(modelCacheInfo.opVersion)), sizeof(uint64_t))) {
|
||||||
|
LOGW("[NNCompiledCache] opVersion failed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
infoCacheFile.close();
|
||||||
return OH_NN_SUCCESS;
|
return OH_NN_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ struct NNCompiledCacheInfo {
|
|||||||
int64_t version{0};
|
int64_t version{0};
|
||||||
int64_t deviceId{0};
|
int64_t deviceId{0};
|
||||||
std::vector<unsigned short> modelCheckSum;
|
std::vector<unsigned short> modelCheckSum;
|
||||||
|
int64_t opVersion{0};
|
||||||
};
|
};
|
||||||
|
|
||||||
class NNCompiledCache {
|
class NNCompiledCache {
|
||||||
|
@ -31,6 +31,7 @@ const int CACHE_INPUT_TENSORDESC_OFFSET = 2;
|
|||||||
const int CACHE_OUTPUT_TENSORDESC_OFFSET = 1;
|
const int CACHE_OUTPUT_TENSORDESC_OFFSET = 1;
|
||||||
constexpr int32_t NUMBER_CACHE_INFO_MEMBERS = 3;
|
constexpr int32_t NUMBER_CACHE_INFO_MEMBERS = 3;
|
||||||
const std::string EXTENSION_KEY_MODEL_NAME = "ModelName";
|
const std::string EXTENSION_KEY_MODEL_NAME = "ModelName";
|
||||||
|
const int OPVERSION_SUBSTR_NUM = 2;
|
||||||
|
|
||||||
struct SerializedTensorDesc {
|
struct SerializedTensorDesc {
|
||||||
public:
|
public:
|
||||||
@ -618,6 +619,16 @@ OH_NN_ReturnCode NNCompiler::RestoreFromCacheFile()
|
|||||||
if (isUpdatable) {
|
if (isUpdatable) {
|
||||||
LOGI("isUpdatable is true");
|
LOGI("isUpdatable is true");
|
||||||
|
|
||||||
|
std::string currentVersion = "0x00000000";
|
||||||
|
std::string path = "/data/data/hiai/version";
|
||||||
|
std::ifstream inf(path.c_str());
|
||||||
|
if (inf.is_open()) {
|
||||||
|
getline(inf, currentVersion);
|
||||||
|
}
|
||||||
|
|
||||||
|
int currentOpVersion = std::stoi(currentVersion.substr(OPVERSION_SUBSTR_NUM));
|
||||||
|
inf.close();
|
||||||
|
|
||||||
NNCompiledCacheInfo modelCacheInfo;
|
NNCompiledCacheInfo modelCacheInfo;
|
||||||
std::string cacheInfoPath = m_cachePath + "/" + m_extensionConfig.modelName + "cache_info.nncache";
|
std::string cacheInfoPath = m_cachePath + "/" + m_extensionConfig.modelName + "cache_info.nncache";
|
||||||
ret = compiledCache.CheckCacheInfo(modelCacheInfo, cacheInfoPath);
|
ret = compiledCache.CheckCacheInfo(modelCacheInfo, cacheInfoPath);
|
||||||
@ -626,31 +637,36 @@ OH_NN_ReturnCode NNCompiler::RestoreFromCacheFile()
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGI("isUpdatable modelCacheInfo");
|
LOGI("isUpdatable currentOpVersion is: %{public}d", currentOpVersion);
|
||||||
|
LOGI("isUpdatable modelCacheInfo opVersion is %{public}d", static_cast<int>(modelCacheInfo.opVersion));
|
||||||
|
|
||||||
const size_t cacheNumber = caches.size();
|
if (currentOpVersion > modelCacheInfo.opVersion) {
|
||||||
uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber;
|
const size_t cacheNumber = caches.size();
|
||||||
uint32_t infoCharNumber = cacheSize * sizeof(int64_t);
|
uint32_t cacheSize = NUMBER_CACHE_INFO_MEMBERS + cacheNumber + 1;
|
||||||
|
uint32_t infoCharNumber = cacheSize * sizeof(int64_t);
|
||||||
|
|
||||||
std::unique_ptr<int64_t[]> cacheInfo = CreateUniquePtr<int64_t[]>(cacheSize);
|
std::unique_ptr<int64_t[]> cacheInfo = CreateUniquePtr<int64_t[]>(cacheSize);
|
||||||
if (cacheInfo == nullptr) {
|
if (cacheInfo == nullptr) {
|
||||||
LOGE("[NNCompiledCache] isUpdatable is true to create unique failed.");
|
LOGE("[NNCompiledCache] isUpdatable is true to create unique failed.");
|
||||||
return OH_NN_MEMORY_ERROR;
|
return OH_NN_MEMORY_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto cacheInfoPtr = cacheInfo.get();
|
auto cacheInfoPtr = cacheInfo.get();
|
||||||
*cacheInfoPtr++ = modelCacheInfo.fileNumber;
|
*cacheInfoPtr++ = modelCacheInfo.fileNumber;
|
||||||
*cacheInfoPtr++ = modelCacheInfo.version - 1;
|
*cacheInfoPtr++ = modelCacheInfo.version - 1;
|
||||||
*cacheInfoPtr++ = modelCacheInfo.deviceId;
|
*cacheInfoPtr++ = modelCacheInfo.deviceId;
|
||||||
|
|
||||||
for (size_t i = 0; i < modelCacheInfo.modelCheckSum.size(); ++i) {
|
for (size_t i = 0; i < modelCacheInfo.modelCheckSum.size(); ++i) {
|
||||||
*cacheInfoPtr++ = static_cast<int64_t>(modelCacheInfo.modelCheckSum[i]);
|
*cacheInfoPtr++ = static_cast<int64_t>(modelCacheInfo.modelCheckSum[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = compiledCache.WriteCacheInfo(infoCharNumber, cacheInfo, m_cachePath);
|
*cacheInfoPtr++ = currentOpVersion;
|
||||||
if (ret != OH_NN_SUCCESS) {
|
|
||||||
LOGE("[NNCompiledCache] isUpdatable is true to write cache info failed.");
|
ret = compiledCache.WriteCacheInfo(infoCharNumber, cacheInfo, m_cachePath);
|
||||||
return ret;
|
if (ret != OH_NN_SUCCESS) {
|
||||||
|
LOGE("[NNCompiledCache] isUpdatable is true to write cache info failed.");
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user