mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
refactor: simplify the model loading logic (#933)
* remove String2GGMLType * remove preprocess_tensor * fix clip init * simplify the logic for reading weights
This commit is contained in:
@@ -213,7 +213,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
bool is_unet = model_loader.model_is_unet();
|
||||
bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
|
||||
LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
|
||||
@@ -273,12 +273,12 @@ public:
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& tensor_types = model_loader.tensor_storages_types;
|
||||
for (auto& item : tensor_types) {
|
||||
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||
if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
|
||||
item.second = GGML_TYPE_F16;
|
||||
// LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
|
||||
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
||||
for (auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (contains(name, "qwen2vl") &&
|
||||
ends_with(name, "weight") &&
|
||||
(tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) {
|
||||
tensor_storage.expected_type = GGML_TYPE_F16;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,13 +344,13 @@ public:
|
||||
if (sd_version_is_sd3(version)) {
|
||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
diffusion_model = std::make_shared<MMDiTModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
bool is_chroma = false;
|
||||
for (auto pair : model_loader.tensor_storages_types) {
|
||||
for (auto pair : tensor_storage_map) {
|
||||
if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||
is_chroma = true;
|
||||
break;
|
||||
@@ -368,42 +368,42 @@ public:
|
||||
|
||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
sd_ctx_params->chroma_use_t5_mask,
|
||||
sd_ctx_params->chroma_t5_mask_pad);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
}
|
||||
diffusion_model = std::make_shared<FluxModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
version,
|
||||
sd_ctx_params->chroma_use_dit_mask);
|
||||
} else if (sd_version_is_wan(version)) {
|
||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
true,
|
||||
1,
|
||||
true);
|
||||
diffusion_model = std::make_shared<WanModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"model.diffusion_model",
|
||||
version);
|
||||
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
|
||||
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"model.high_noise_diffusion_model",
|
||||
version);
|
||||
}
|
||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
tensor_storage_map);
|
||||
clip_vision->alloc_params_buffer();
|
||||
clip_vision->get_param_tensors(tensors);
|
||||
}
|
||||
@@ -414,32 +414,32 @@ public:
|
||||
}
|
||||
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"",
|
||||
enable_vision);
|
||||
diffusion_model = std::make_shared<QwenImageModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"model.diffusion_model",
|
||||
version);
|
||||
} else { // SD1.x SD2.x SDXL
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
version,
|
||||
PM_VERSION_2);
|
||||
} else {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
version);
|
||||
}
|
||||
diffusion_model = std::make_shared<UNetModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
version);
|
||||
if (sd_ctx_params->diffusion_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the diffusion model");
|
||||
@@ -477,7 +477,7 @@ public:
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
version);
|
||||
@@ -489,7 +489,7 @@ public:
|
||||
} else if (!use_tiny_autoencoder) {
|
||||
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"first_stage_model",
|
||||
vae_decode_only,
|
||||
false,
|
||||
@@ -512,7 +512,7 @@ public:
|
||||
} else {
|
||||
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"decoder.layers",
|
||||
vae_decode_only,
|
||||
version);
|
||||
@@ -533,7 +533,7 @@ public:
|
||||
}
|
||||
control_net = std::make_shared<ControlNet>(controlnet_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
version);
|
||||
if (sd_ctx_params->diffusion_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the control net");
|
||||
@@ -544,7 +544,7 @@ public:
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"pmid",
|
||||
version,
|
||||
PM_VERSION_2);
|
||||
@@ -552,7 +552,7 @@ public:
|
||||
} else {
|
||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
tensor_storage_map,
|
||||
"pmid",
|
||||
version);
|
||||
}
|
||||
@@ -733,12 +733,12 @@ public:
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
|
||||
if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
|
||||
// CosXL models
|
||||
// TODO: get sigma_min and sigma_max values from file
|
||||
is_using_edm_v_parameterization = true;
|
||||
}
|
||||
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
|
||||
if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (version == VERSION_SVD) {
|
||||
@@ -758,10 +758,9 @@ public:
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 1.0f; // TODO: validate
|
||||
for (auto pair : model_loader.tensor_storages_types) {
|
||||
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
|
||||
for (const auto& [name, tensor_storage] : tensor_storage_map) {
|
||||
if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
|
||||
shift = 1.15f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user