mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
refactor: introduce GGMLRunnerContext (#928)
* introduce GGMLRunnerContext
* add Flash Attention enable control through GGMLRunnerContext
* add conv2d_direct enable control through GGMLRunnerContext
This commit is contained in:
@@ -341,16 +341,12 @@ public:
|
||||
LOG_INFO("CLIP: Using CPU backend");
|
||||
clip_backend = ggml_backend_cpu_init();
|
||||
}
|
||||
if (sd_ctx_params->diffusion_flash_attn) {
|
||||
LOG_INFO("Using flash attention in the diffusion model");
|
||||
}
|
||||
if (sd_version_is_sd3(version)) {
|
||||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
diffusion_model = std::make_shared<MMDiTModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
sd_ctx_params->diffusion_flash_attn,
|
||||
model_loader.tensor_storages_types);
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
bool is_chroma = false;
|
||||
@@ -384,7 +380,6 @@ public:
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn,
|
||||
sd_ctx_params->chroma_use_dit_mask);
|
||||
} else if (sd_version_is_wan(version)) {
|
||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(clip_backend,
|
||||
@@ -397,15 +392,13 @@ public:
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
"model.diffusion_model",
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn);
|
||||
version);
|
||||
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
|
||||
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
"model.high_noise_diffusion_model",
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn);
|
||||
version);
|
||||
}
|
||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
||||
@@ -428,8 +421,7 @@ public:
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
"model.diffusion_model",
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn);
|
||||
version);
|
||||
} else { // SD1.x SD2.x SDXL
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
@@ -448,14 +440,18 @@ public:
|
||||
diffusion_model = std::make_shared<UNetModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn);
|
||||
version);
|
||||
if (sd_ctx_params->diffusion_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the diffusion model");
|
||||
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.enable_conv2d_direct();
|
||||
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.set_conv2d_direct_enabled(true);
|
||||
}
|
||||
}
|
||||
|
||||
if (sd_ctx_params->diffusion_flash_attn) {
|
||||
LOG_INFO("Using flash attention in the diffusion model");
|
||||
diffusion_model->set_flash_attn_enabled(true);
|
||||
}
|
||||
|
||||
cond_stage_model->alloc_params_buffer();
|
||||
cond_stage_model->get_param_tensors(tensors);
|
||||
|
||||
@@ -500,7 +496,7 @@ public:
|
||||
version);
|
||||
if (sd_ctx_params->vae_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the vae model");
|
||||
first_stage_model->enable_conv2d_direct();
|
||||
first_stage_model->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
if (version == VERSION_SDXL &&
|
||||
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
|
||||
@@ -522,7 +518,7 @@ public:
|
||||
version);
|
||||
if (sd_ctx_params->vae_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the tae model");
|
||||
tae_first_stage->enable_conv2d_direct();
|
||||
tae_first_stage->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
}
|
||||
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
|
||||
@@ -541,7 +537,7 @@ public:
|
||||
version);
|
||||
if (sd_ctx_params->diffusion_conv_direct) {
|
||||
LOG_INFO("Using Conv2d direct in the control net");
|
||||
control_net->enable_conv2d_direct();
|
||||
control_net->set_conv2d_direct_enabled(true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user