refactor: optimize the logic for name conversion and the processing of the LoRA model (#955)

Author: leejet
Date: 2025-11-10 00:12:20 +08:00
Committed by: GitHub
Parent: 8ecdf053ac
Commit: 694f0d9235
20 changed files with 1670 additions and 1415 deletions


@@ -278,6 +278,8 @@ public:
}
}
+ model_loader.convert_tensors_name();
version = model_loader.get_sd_version();
if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
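The new call here is model_loader.convert_tensors_name(), which now runs before get_sd_version(), so version detection works on already-normalized tensor names. A minimal sketch of what such a normalization pass could look like, assuming a hypothetical prefix table; the mapping entries and helper below are illustrative, not the project's actual rules:

// Illustrative only: remap known external tensor-name prefixes to an
// internal scheme before anything else inspects the names.
#include <string>
#include <utility>
#include <vector>

// Hypothetical prefix map; the real loader's rules are far more detailed.
static const std::vector<std::pair<std::string, std::string>> kPrefixMap = {
    {"model.diffusion_model.", "unet."},
    {"first_stage_model.", "vae."},
    {"cond_stage_model.", "te."},
};

std::string convert_tensor_name(const std::string& name) {
    for (const auto& [from, to] : kPrefixMap) {
        if (name.rfind(from, 0) == 0) {  // name starts with `from`
            return to + name.substr(from.size());
        }
    }
    return name;  // unknown names pass through unchanged
}

Running a pass like this once over every stored tensor name means the version heuristics that follow only have to match a single naming scheme.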
@@ -569,13 +571,13 @@ public:
version);
}
if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) {
- pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "");
+ pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "", version);
if (!pmid_lora->load_from_file(true, n_threads)) {
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path);
return false;
}
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path);
- if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) {
+ if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path);
} else {
stacked_id = true;
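init_from_file_and_convert_name replaces the bare init_from_file call for the PhotoMaker weights. The diff does not show its body, but the name suggests it simply bundles the two steps so call sites cannot forget the conversion; a sketch under that assumption, with a stub class and placeholder default argument:

#include <string>

// Minimal stand-in for the real ModelLoader; only what the sketch needs.
struct ModelLoader {
    bool init_from_file(const std::string& path, const std::string& prefix);
    void convert_tensors_name();

    // Assumed composition of the two existing steps; the real body may
    // differ, this only shows the intended call order.
    bool init_from_file_and_convert_name(const std::string& path,
                                         const std::string& prefix = "") {
        if (!init_from_file(path, prefix)) {
            return false;
        }
        convert_tensors_name();
        return true;
    }
};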
@@ -609,7 +611,7 @@ public:
ignore_tensors.insert("first_stage_model.");
}
if (stacked_id) {
ignore_tensors.insert("lora.");
ignore_tensors.insert("pmid.unet.");
}
if (vae_decode_only) {
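When stacked-ID (PhotoMaker) loading is active, the ignore list now skips the narrower "pmid.unet." prefix instead of the generic "lora." prefix, presumably because the converted PhotoMaker LoRA tensors carry that prefix and are loaded through their own LoraModel. A prefix filter of this kind is typically just a starts-with check over the set; a hedged sketch, with an illustrative function name:

#include <set>
#include <string>

// Illustrative prefix check: skip any tensor whose name starts with one of
// the ignored prefixes. Not the project's actual loading loop.
bool should_ignore(const std::string& name,
                   const std::set<std::string>& ignore_tensors) {
    for (const auto& prefix : ignore_tensors) {
        if (name.compare(0, prefix.size(), prefix) == 0) {
            return true;
        }
    }
    return false;
}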
@@ -925,7 +927,7 @@ public:
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
return;
}
- LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "");
+ LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "", version);
if (!lora.load_from_file(false, n_threads)) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
return;
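Both LoraModel construction sites now receive the detected version alongside the optional name prefix ("model.high_noise_" in the high-noise case, "" otherwise). Knowing the SD version at construction time lets the LoRA loader pick version-specific name-conversion behavior up front rather than re-deriving it per tensor; a rough sketch of that shape, with stand-in type and member names that are not the project's real API:

#include <string>
#include <utility>

// Stand-ins only; the project's real SDVersion enum and LoraModel differ.
enum class SDVersion { SD1, SD2, SDXL, SD3, FLUX, UNKNOWN };

struct LoraModelSketch {
    std::string prefix;   // e.g. "model.high_noise_" or ""
    SDVersion version;    // detected once by the main loader, passed in here

    LoraModelSketch(std::string prefix_, SDVersion version_)
        : prefix(std::move(prefix_)), version(version_) {}

    // Version-aware name conversion would branch on `version` here;
    // a simple prefix pass-through is shown as a placeholder.
    std::string convert_name(const std::string& raw) const {
        return prefix + raw;
    }
};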