refactor: reorganize code and use c api (#133)

This commit is contained in:
leejet
2024-01-01 16:22:18 +08:00
committed by GitHub
parent b139434b57
commit 2e79a82f85
22 changed files with 530311 additions and 49428 deletions

View File

@@ -1,6 +1,7 @@
#include <stdarg.h>
#include <fstream>
#include <regex>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
@@ -1367,6 +1368,72 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
return success;
}
bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
ggml_backend_t backend,
std::set<std::string> ignore_tensors) {
std::set<std::string> tensor_names_in_file;
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;
tensor_names_in_file.insert(name);
struct ggml_tensor* real;
if (tensors.find(name) != tensors.end()) {
real = tensors[name];
} else {
if (ignore_tensors.find(name) == ignore_tensors.end()) {
LOG_WARN("unknown tensor '%s' in model file", name.c_str());
}
return true;
}
if (
real->ne[0] != tensor_storage.ne[0] ||
real->ne[1] != tensor_storage.ne[1] ||
real->ne[2] != tensor_storage.ne[2] ||
real->ne[3] != tensor_storage.ne[3]) {
LOG_ERROR(
"tensor '%s' has wrong shape in model file: "
"got [%d, %d, %d, %d], expected [%d, %d, %d, %d]",
name.c_str(),
(int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3],
(int)real->ne[0], (int)real->ne[1], (int)real->ne[2], (int)real->ne[3]);
return false;
}
*dst_tensor = real;
return true;
};
bool success = load_tensors(on_new_tensor_cb, backend);
if (!success) {
LOG_ERROR("load tensors from file failed");
return false;
}
bool some_tensor_not_init = false;
for (auto pair : tensors) {
if (pair.first.find("cond_stage_model.transformer.text_model.encoder.layers.23") != std::string::npos) {
continue;
}
if (pair.first.find("alphas_cumprod") != std::string::npos) {
continue;
}
if (tensor_names_in_file.find(pair.first) == tensor_names_in_file.end()) {
LOG_ERROR("tensor '%s' not in model file", pair.first.c_str());
some_tensor_not_init = true;
}
}
if (some_tensor_not_init) {
return false;
}
return true;
}
int64_t ModelLoader::cal_mem_size(ggml_backend_t backend) {
size_t alignment = 128;
if (backend != NULL) {