mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
refactor: optimize the logic for name conversion and the processing of the LoRA model (#955)
This commit is contained in:
@@ -278,6 +278,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
model_loader.convert_tensors_name();
|
||||
|
||||
version = model_loader.get_sd_version();
|
||||
if (version == VERSION_COUNT) {
|
||||
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
|
||||
@@ -569,13 +571,13 @@ public:
|
||||
version);
|
||||
}
|
||||
if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) {
|
||||
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "");
|
||||
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "", version);
|
||||
if (!pmid_lora->load_from_file(true, n_threads)) {
|
||||
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path);
|
||||
return false;
|
||||
}
|
||||
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) {
|
||||
if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) {
|
||||
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path);
|
||||
} else {
|
||||
stacked_id = true;
|
||||
@@ -609,7 +611,7 @@ public:
|
||||
ignore_tensors.insert("first_stage_model.");
|
||||
}
|
||||
if (stacked_id) {
|
||||
ignore_tensors.insert("lora.");
|
||||
ignore_tensors.insert("pmid.unet.");
|
||||
}
|
||||
|
||||
if (vae_decode_only) {
|
||||
@@ -925,7 +927,7 @@ public:
|
||||
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
|
||||
return;
|
||||
}
|
||||
LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "");
|
||||
LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "", version);
|
||||
if (!lora.load_from_file(false, n_threads)) {
|
||||
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user