fix: allow model and vae to use different formats

leejet
2023-12-03 17:12:04 +08:00
parent d7af2c2ba9
commit 8a87b273ad
4 changed files with 266 additions and 305 deletions


@@ -3281,9 +3281,10 @@ struct LoraModel {
     bool load(ggml_backend_t backend_, std::string file_path) {
         backend = backend_;
         LOG_INFO("loading LoRA from '%s'", file_path.c_str());
-        std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(file_path));
-        if (!model_loader) {
+        ModelLoader model_loader;
+
+        if (!model_loader.init_from_file(file_path)) {
             LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
             return false;
         }
 
@@ -3299,10 +3300,10 @@ struct LoraModel {
             return false;
         }
 
-        ggml_type wtype = model_loader->get_sd_wtype();
+        ggml_type wtype = model_loader.get_sd_wtype();
         LOG_DEBUG("calculating buffer size");
-        int64_t memory_buffer_size = model_loader->cal_mem_size();
+        int64_t memory_buffer_size = model_loader.cal_mem_size();
         LOG_DEBUG("lora params backend buffer size = % 6.2f MB", memory_buffer_size / (1024.0 * 1024.0));
 
         params_buffer_lora = ggml_backend_alloc_buffer(backend, memory_buffer_size);
@@ -3320,7 +3321,7 @@ struct LoraModel {
             return true;
         };
 
-        model_loader->load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb);
 
         LOG_DEBUG("finished loaded lora");
         ggml_allocr_free(alloc);
@@ -3664,21 +3665,21 @@ public:
 #endif
 #endif
 
         LOG_INFO("loading model from '%s'", model_path.c_str());
-        std::shared_ptr<ModelLoader> model_loader = std::shared_ptr<ModelLoader>(init_model_loader_from_file(model_path));
-        if (!model_loader) {
+        ModelLoader model_loader;
+        if (!model_loader.init_from_file(model_path)) {
             LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str());
             return false;
         }
 
         if (vae_path.size() > 0) {
             LOG_INFO("loading vae from '%s'", vae_path.c_str());
-            if (!model_loader->init_from_file(vae_path, "vae.")) {
+            if (!model_loader.init_from_file(vae_path, "vae.")) {
                 LOG_WARN("loading vae from '%s' failed", vae_path.c_str());
             }
         }
 
-        SDVersion version = model_loader->get_sd_version();
+        SDVersion version = model_loader.get_sd_version();
         if (version == VERSION_COUNT) {
             LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
             return false;
@@ -3687,7 +3688,7 @@ public:
         diffusion_model = UNetModel(version);
         LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
         if (wtype == GGML_TYPE_COUNT) {
-            model_data_type = model_loader->get_sd_wtype();
+            model_data_type = model_loader.get_sd_wtype();
         } else {
             model_data_type = wtype;
         }
@@ -3697,7 +3698,7 @@ public:
         auto add_token = [&](const std::string& token, int32_t token_id) {
             cond_stage_model.tokenizer.add_token(token, token_id);
         };
-        bool success = model_loader->load_vocab(add_token);
+        bool success = model_loader.load_vocab(add_token);
         if (!success) {
             LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
             return false;
@@ -3794,7 +3795,7 @@ public:
         // print_ggml_tensor(alphas_cumprod_tensor);
 
-        success = model_loader->load_tensors(on_new_tensor_cb);
+        success = model_loader.load_tensors(on_new_tensor_cb);
         if (!success) {
             LOG_ERROR("load tensors from file failed");
             ggml_free(ctx);
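
Taken together, the change replaces the heap-allocated, single-shot init_model_loader_from_file factory with a stack-allocated ModelLoader whose init_from_file can be called more than once, so a checkpoint and a standalone VAE file in a different format can be merged into one tensor index, with the VAE's tensors registered under the "vae." prefix. Below is a minimal sketch of that flow; the ModelLoader body, load_model_and_vae, and the file names are hypothetical stand-ins, and only the method names and the "vae." prefix mirror the diff:

    #include <cstdio>
    #include <string>

    // Hypothetical, simplified stand-in for the real ModelLoader.
    struct ModelLoader {
        // Index tensors from `path`, prepending `prefix` to each tensor name.
        // Calling this again merges a second file, which may use another format.
        bool init_from_file(const std::string& path, const std::string& prefix = "") {
            std::printf("indexing '%s' under prefix '%s'\n", path.c_str(), prefix.c_str());
            return true;  // the real loader returns false on parse/format errors
        }
    };

    // Mirrors the loading order in the diff: main model first, optional VAE second.
    bool load_model_and_vae(const std::string& model_path, const std::string& vae_path) {
        ModelLoader model_loader;  // plain value on the stack, no shared_ptr factory
        if (!model_loader.init_from_file(model_path)) {
            return false;  // a broken main model is fatal
        }
        if (!vae_path.empty() && !model_loader.init_from_file(vae_path, "vae.")) {
            // non-fatal, matching the LOG_WARN in the diff
            std::printf("loading vae from '%s' failed\n", vae_path.c_str());
        }
        return true;
    }

    int main() {
        return load_model_and_vae("sd-v1-4.ckpt", "vae.safetensors") ? 0 : 1;
    }

Because both files feed the same tensor index, the format check happens per file rather than once per model, which is what lets the checkpoint and the VAE use different formats.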