refactor: simplify the model loading logic (#933)

* remove String2GGMLType

* remove preprocess_tensor

* fix clip init

* simplify the logic for reading weights
leejet authored on 2025-11-03 21:21:34 +08:00, committed by GitHub
parent 6103d86e2c
commit 8f6c5c217b
21 changed files with 534 additions and 622 deletions
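
The central change visible in the diff: the flat tensor_storages_types map (tensor name -> ggml type) is replaced by get_tensor_storage_map(), which yields a richer per-tensor TensorStorage entry. A minimal sketch of the shape of that change; only the type and expected_type fields are confirmed by the hunks below, everything else is illustrative:

    #include <map>
    #include <string>
    #include "ggml.h"

    struct TensorStorage {
        ggml_type type          = GGML_TYPE_F32;  // type as stored in the file
        ggml_type expected_type = GGML_TYPE_F32;  // type to convert to at load time
        // ... name, shape, file offset, etc. (assumed)
    };

    // Before: std::map<std::string, ggml_type> tensor_storages_types;
    // After:  a name -> TensorStorage map, so callers can request a different
    //         load-time type without losing the on-disk type.
    using TensorStorageMap = std::map<std::string, TensorStorage>;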

@@ -213,7 +213,7 @@ public:
             }
         }
-        bool is_unet = model_loader.model_is_unet();
+        bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
         if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
             LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
@@ -273,12 +273,12 @@ public:
             return false;
         }
-        auto& tensor_types = model_loader.tensor_storages_types;
-        for (auto& item : tensor_types) {
-            // LOG_DEBUG("%s %u", item.first.c_str(), item.second);
-            if (contains(item.first, "qwen2vl") && ends_with(item.first, "weight") && (item.second == GGML_TYPE_F32 || item.second == GGML_TYPE_BF16)) {
-                item.second = GGML_TYPE_F16;
-                // LOG_DEBUG(" change %s %u", item.first.c_str(), item.second);
+        auto& tensor_storage_map = model_loader.get_tensor_storage_map();
+        for (auto& [name, tensor_storage] : tensor_storage_map) {
+            if (contains(name, "qwen2vl") &&
+                ends_with(name, "weight") &&
+                (tensor_storage.type == GGML_TYPE_F32 || tensor_storage.type == GGML_TYPE_BF16)) {
+                tensor_storage.expected_type = GGML_TYPE_F16;
             }
         }
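
Note the semantic shift in the hunk above: the old loop overwrote the recorded type (item.second = GGML_TYPE_F16), while the new loop leaves tensor_storage.type (the on-disk type) untouched and sets expected_type instead, so the loader converts qwen2vl F32/BF16 weights to F16 while reading. A standalone sketch of that pattern, assuming the project's contains()/ends_with() helpers and the TensorStorage fields sketched earlier:

    for (auto& [name, tensor_storage] : tensor_storage_map) {
        bool is_qwen2vl_weight = contains(name, "qwen2vl") && ends_with(name, "weight");
        bool is_wide_float     = tensor_storage.type == GGML_TYPE_F32 ||
                                 tensor_storage.type == GGML_TYPE_BF16;
        if (is_qwen2vl_weight && is_wide_float) {
            // Request F16 at load time; the recorded on-disk type stays intact.
            tensor_storage.expected_type = GGML_TYPE_F16;
        }
    }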
@@ -344,13 +344,13 @@ public:
         if (sd_version_is_sd3(version)) {
             cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
                                                                  offload_params_to_cpu,
-                                                                 model_loader.tensor_storages_types);
+                                                                 tensor_storage_map);
             diffusion_model = std::make_shared<MMDiTModel>(backend,
                                                            offload_params_to_cpu,
-                                                           model_loader.tensor_storages_types);
+                                                           tensor_storage_map);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;
-            for (auto pair : model_loader.tensor_storages_types) {
+            for (auto pair : tensor_storage_map) {
                 if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
                     is_chroma = true;
                     break;
@@ -368,42 +368,42 @@ public:
                 cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
                                                                     offload_params_to_cpu,
-                                                                    model_loader.tensor_storages_types,
+                                                                    tensor_storage_map,
                                                                     sd_ctx_params->chroma_use_t5_mask,
                                                                     sd_ctx_params->chroma_t5_mask_pad);
             } else {
                 cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend,
                                                                       offload_params_to_cpu,
-                                                                      model_loader.tensor_storages_types);
+                                                                      tensor_storage_map);
             }
             diffusion_model = std::make_shared<FluxModel>(backend,
                                                           offload_params_to_cpu,
-                                                          model_loader.tensor_storages_types,
+                                                          tensor_storage_map,
                                                           version,
                                                           sd_ctx_params->chroma_use_dit_mask);
         } else if (sd_version_is_wan(version)) {
             cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
                                                                 offload_params_to_cpu,
-                                                                model_loader.tensor_storages_types,
+                                                                tensor_storage_map,
                                                                 true,
                                                                 1,
                                                                 true);
             diffusion_model = std::make_shared<WanModel>(backend,
                                                          offload_params_to_cpu,
-                                                         model_loader.tensor_storages_types,
+                                                         tensor_storage_map,
                                                          "model.diffusion_model",
                                                          version);
             if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
                 high_noise_diffusion_model = std::make_shared<WanModel>(backend,
                                                                         offload_params_to_cpu,
-                                                                        model_loader.tensor_storages_types,
+                                                                        tensor_storage_map,
                                                                         "model.high_noise_diffusion_model",
                                                                         version);
             }
             if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
                 clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
                                                                          offload_params_to_cpu,
-                                                                         model_loader.tensor_storages_types);
+                                                                         tensor_storage_map);
                 clip_vision->alloc_params_buffer();
                 clip_vision->get_param_tensors(tensors);
             }
@@ -414,32 +414,32 @@ public:
             }
             cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
                                                                         offload_params_to_cpu,
-                                                                        model_loader.tensor_storages_types,
+                                                                        tensor_storage_map,
                                                                         "",
                                                                         enable_vision);
             diffusion_model = std::make_shared<QwenImageModel>(backend,
                                                                offload_params_to_cpu,
-                                                               model_loader.tensor_storages_types,
+                                                               tensor_storage_map,
                                                                "model.diffusion_model",
                                                                version);
         } else { // SD1.x SD2.x SDXL
             if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
                                                                                        offload_params_to_cpu,
-                                                                                       model_loader.tensor_storages_types,
+                                                                                       tensor_storage_map,
                                                                                        SAFE_STR(sd_ctx_params->embedding_dir),
                                                                                        version,
                                                                                        PM_VERSION_2);
             } else {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
                                                                                        offload_params_to_cpu,
-                                                                                       model_loader.tensor_storages_types,
+                                                                                       tensor_storage_map,
                                                                                        SAFE_STR(sd_ctx_params->embedding_dir),
                                                                                        version);
             }
             diffusion_model = std::make_shared<UNetModel>(backend,
                                                           offload_params_to_cpu,
-                                                          model_loader.tensor_storages_types,
+                                                          tensor_storage_map,
                                                           version);
             if (sd_ctx_params->diffusion_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the diffusion model");
@@ -477,7 +477,7 @@ public:
         if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
             first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
                                                                     offload_params_to_cpu,
-                                                                    model_loader.tensor_storages_types,
+                                                                    tensor_storage_map,
                                                                     "first_stage_model",
                                                                     vae_decode_only,
                                                                     version);
@@ -489,7 +489,7 @@ public:
         } else if (!use_tiny_autoencoder) {
             first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
                                                                 offload_params_to_cpu,
-                                                                model_loader.tensor_storages_types,
+                                                                tensor_storage_map,
                                                                 "first_stage_model",
                                                                 vae_decode_only,
                                                                 false,
@@ -512,7 +512,7 @@ public:
         } else {
             tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
                                                                 offload_params_to_cpu,
-                                                                model_loader.tensor_storages_types,
+                                                                tensor_storage_map,
                                                                 "decoder.layers",
                                                                 vae_decode_only,
                                                                 version);
@@ -533,7 +533,7 @@ public:
             }
             control_net = std::make_shared<ControlNet>(controlnet_backend,
                                                        offload_params_to_cpu,
-                                                       model_loader.tensor_storages_types,
+                                                       tensor_storage_map,
                                                        version);
             if (sd_ctx_params->diffusion_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the control net");
@@ -544,7 +544,7 @@ public:
         if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
             pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
                                                                offload_params_to_cpu,
-                                                               model_loader.tensor_storages_types,
+                                                               tensor_storage_map,
                                                                "pmid",
                                                                version,
                                                                PM_VERSION_2);
@@ -552,7 +552,7 @@ public:
         } else {
             pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
                                                                offload_params_to_cpu,
-                                                               model_loader.tensor_storages_types,
+                                                               tensor_storage_map,
                                                                "pmid",
                                                                version);
         }
@@ -733,12 +733,12 @@ public:
                 is_using_v_parameterization = true;
             }
         } else if (sd_version_is_sdxl(version)) {
-            if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
+            if (tensor_storage_map.find("edm_vpred.sigma_max") != tensor_storage_map.end()) {
                 // CosXL models
                 // TODO: get sigma_min and sigma_max values from file
                 is_using_edm_v_parameterization = true;
             }
-            if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
+            if (tensor_storage_map.find("v_pred") != tensor_storage_map.end()) {
                 is_using_v_parameterization = true;
             }
         } else if (version == VERSION_SVD) {
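The parameterization checks above key off sentinel entries in the checkpoint: SDXL v-prediction models ship a "v_pred" tensor and CosXL models ship "edm_vpred.sigma_max", so the mere presence of the key selects the parameterization. A sketch of that detection, reusing the TensorStorageMap alias assumed earlier:

    // Sketch: detect parameterization from sentinel tensor names (names taken
    // from the hunk above; the function itself is hypothetical).
    void detect_sdxl_parameterization(const TensorStorageMap& tensor_storage_map,
                                      bool& is_using_v_parameterization,
                                      bool& is_using_edm_v_parameterization) {
        if (tensor_storage_map.count("edm_vpred.sigma_max") > 0) {
            is_using_edm_v_parameterization = true;  // CosXL-style EDM model
        }
        if (tensor_storage_map.count("v_pred") > 0) {
            is_using_v_parameterization = true;  // SDXL v-prediction model
        }
    }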
@@ -758,10 +758,9 @@ public:
         float shift = sd_ctx_params->flow_shift;
         if (shift == INFINITY) {
             shift = 1.0f; // TODO: validate
-            for (auto pair : model_loader.tensor_storages_types) {
-                if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
+            for (const auto& [name, tensor_storage] : tensor_storage_map) {
+                if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
                     shift = 1.15f;
                     break;
                 }
             }
         }
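
The default flow shift uses the same key-probing idea: a checkpoint containing model.diffusion_model.guidance_in.in_layer.weight suggests a guidance-distilled (Flux-style) model and gets a 1.15 default shift instead of 1.0. A sketch under the same assumptions, with starts_with() taken to be the project's string helper:

    // Sketch of the flow-shift default heuristic from the hunk above;
    // the function name is hypothetical.
    float resolve_default_flow_shift(const TensorStorageMap& tensor_storage_map) {
        float shift = 1.0f;  // generic default
        for (const auto& [name, tensor_storage] : tensor_storage_map) {
            (void)tensor_storage;  // only the key matters here
            if (starts_with(name, "model.diffusion_model.guidance_in.in_layer.weight")) {
                shift = 1.15f;  // guidance-distilled checkpoints expect a higher shift
                break;
            }
        }
        return shift;
    }

If the key is always stored verbatim, a direct tensor_storage_map.count(...) lookup would avoid the linear scan; the loop form also matches names that merely begin with that prefix.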