feat: add easycache support (#940)

This commit is contained in:
rmatif
2025-11-19 16:19:32 +01:00
committed by GitHub
parent 28ffb6c13d
commit a14e2b321d
5 changed files with 541 additions and 32 deletions

View File

@@ -11,6 +11,7 @@
#include "control.hpp"
#include "denoiser.hpp"
#include "diffusion_model.hpp"
#include "easycache.hpp"
#include "esrgan.hpp"
#include "lora.hpp"
#include "pmid.hpp"
@@ -1481,11 +1482,12 @@ public:
const std::vector<float>& sigmas,
int start_merge_step,
SDCondition id_cond,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
ggml_tensor* denoise_mask = nullptr,
ggml_tensor* vace_context = nullptr,
float vace_strength = 1.f) {
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false,
ggml_tensor* denoise_mask = nullptr,
ggml_tensor* vace_context = nullptr,
float vace_strength = 1.f,
const sd_easycache_params_t* easycache_params = nullptr) {
if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) {
LOG_WARN("timestep shifting is only supported for SDXL models!");
shifted_timestep = 0;
@@ -1501,6 +1503,42 @@ public:
img_cfg_scale = cfg_scale;
}
EasyCacheState easycache_state;
bool easycache_enabled = false;
if (easycache_params != nullptr && easycache_params->enabled) {
bool easycache_supported = sd_version_is_dit(version);
if (!easycache_supported) {
LOG_WARN("EasyCache requested but not supported for this model type");
} else {
EasyCacheConfig easycache_config;
easycache_config.enabled = true;
easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold);
easycache_config.start_percent = easycache_params->start_percent;
easycache_config.end_percent = easycache_params->end_percent;
bool percent_valid = easycache_config.start_percent >= 0.0f &&
easycache_config.start_percent < 1.0f &&
easycache_config.end_percent > 0.0f &&
easycache_config.end_percent <= 1.0f &&
easycache_config.start_percent < easycache_config.end_percent;
if (!percent_valid) {
LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)",
easycache_config.start_percent,
easycache_config.end_percent);
} else {
easycache_state.init(easycache_config, denoiser.get());
if (easycache_state.enabled()) {
easycache_enabled = true;
LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
easycache_config.reuse_threshold,
easycache_config.start_percent,
easycache_config.end_percent);
} else {
LOG_WARN("EasyCache requested but could not be initialized for this run");
}
}
}
}
size_t steps = sigmas.size() - 1;
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
copy_ggml_tensor(x, init_latent);
@@ -1571,6 +1609,38 @@ public:
pretty_progress(0, (int)steps, 0);
}
DiffusionParams diffusion_params;
const bool easycache_step_active = easycache_enabled && step > 0;
int easycache_step_index = easycache_step_active ? (step - 1) : -1;
if (easycache_step_active) {
easycache_state.begin_step(easycache_step_index, sigma);
}
auto easycache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
return false;
}
return easycache_state.before_condition(condition,
diffusion_params.x,
output_tensor,
sigma,
easycache_step_index);
};
auto easycache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
return;
}
easycache_state.after_condition(condition,
diffusion_params.x,
output_tensor);
};
auto easycache_step_is_skipped = [&]() {
return easycache_step_active && easycache_state.is_step_skipped();
};
std::vector<float> scaling = denoiser->get_scalings(sigma);
GGML_ASSERT(scaling.size() == 3);
float c_skip = scaling[0];
@@ -1616,7 +1686,6 @@ public:
// GGML_ASSERT(0);
}
DiffusionParams diffusion_params;
diffusion_params.x = noised_input;
diffusion_params.timesteps = timesteps;
diffusion_params.guidance = guidance_tensor;
@@ -1627,37 +1696,50 @@ public:
diffusion_params.vace_context = vace_context;
diffusion_params.vace_strength = vace_strength;
const SDCondition* active_condition = nullptr;
struct ggml_tensor** active_output = &out_cond;
if (start_merge_step == -1 || step <= start_merge_step) {
// cond
diffusion_params.context = cond.c_crossattn;
diffusion_params.c_concat = cond.c_concat;
diffusion_params.y = cond.c_vector;
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_cond);
active_condition = &cond;
} else {
diffusion_params.context = id_cond.c_crossattn;
diffusion_params.c_concat = cond.c_concat;
diffusion_params.y = id_cond.c_vector;
active_condition = &id_cond;
}
bool skip_model = easycache_before_condition(active_condition, *active_output);
if (!skip_model) {
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_cond);
active_output);
easycache_after_condition(active_condition, *active_output);
}
bool current_step_skipped = easycache_step_is_skipped();
float* negative_data = nullptr;
if (has_unconditioned) {
// uncond
if (control_hint != nullptr && control_net != nullptr) {
if (!current_step_skipped && control_hint != nullptr && control_net != nullptr) {
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
controls = control_net->controls;
}
current_step_skipped = easycache_step_is_skipped();
diffusion_params.controls = controls;
diffusion_params.context = uncond.c_crossattn;
diffusion_params.c_concat = uncond.c_concat;
diffusion_params.y = uncond.c_vector;
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_uncond);
bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
if (!skip_uncond) {
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_uncond);
easycache_after_condition(&uncond, out_uncond);
}
negative_data = (float*)out_uncond->data;
}
@@ -1666,25 +1748,31 @@ public:
diffusion_params.context = img_cond.c_crossattn;
diffusion_params.c_concat = img_cond.c_concat;
diffusion_params.y = img_cond.c_vector;
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_img_cond);
bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
if (!skip_img_cond) {
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_img_cond);
easycache_after_condition(&img_cond, out_img_cond);
}
img_cond_data = (float*)out_img_cond->data;
}
int step_count = sigmas.size();
bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count);
float* skip_layer_data = nullptr;
float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
if (is_skiplayer_step) {
LOG_DEBUG("Skipping layers at step %d\n", step);
// skip layer (same as conditionned)
diffusion_params.context = cond.c_crossattn;
diffusion_params.c_concat = cond.c_concat;
diffusion_params.y = cond.c_vector;
diffusion_params.skip_layers = skip_layers;
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_skip);
if (!easycache_step_is_skipped()) {
// skip layer (same as conditioned)
diffusion_params.context = cond.c_crossattn;
diffusion_params.c_concat = cond.c_concat;
diffusion_params.y = cond.c_vector;
diffusion_params.skip_layers = skip_layers;
work_diffusion_model->compute(n_threads,
diffusion_params,
&out_skip);
}
skip_layer_data = (float*)out_skip->data;
}
float* vec_denoised = (float*)denoised->data;
@@ -1748,6 +1836,26 @@ public:
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
if (easycache_enabled) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
if (easycache_state.total_steps_skipped > 0 && total_steps > 0) {
if (easycache_state.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - easycache_state.total_steps_skipped);
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
easycache_state.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("EasyCache skipped %d/%zu steps",
easycache_state.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("EasyCache completed without skipping steps");
}
}
if (inverse_noise_scaling) {
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
}
@@ -2294,6 +2402,14 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
return LORA_APPLY_MODE_COUNT;
}
void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
*easycache_params = {};
easycache_params->enabled = false;
easycache_params->reuse_threshold = 0.2f;
easycache_params->start_percent = 0.15f;
easycache_params->end_percent = 0.95f;
}
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
*sd_ctx_params = {};
sd_ctx_params->vae_decode_only = true;
@@ -2452,6 +2568,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_easycache_params_init(&sd_img_gen_params->easycache);
}
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
@@ -2495,6 +2612,12 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->pm_params.id_images_count,
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
snprintf(buf + strlen(buf), 4096 - strlen(buf),
"easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
sd_img_gen_params->easycache.enabled ? "enabled" : "disabled",
sd_img_gen_params->easycache.reuse_threshold,
sd_img_gen_params->easycache.start_percent,
sd_img_gen_params->easycache.end_percent);
free(sample_params_str);
return buf;
}
@@ -2511,6 +2634,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->video_frames = 6;
sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f;
sd_easycache_params_init(&sd_vid_gen_params->easycache);
}
struct sd_ctx_t {
@@ -2578,8 +2702,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
std::vector<sd_image_t*> ref_images,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index,
ggml_tensor* concat_latent = nullptr,
ggml_tensor* denoise_mask = nullptr) {
ggml_tensor* concat_latent = nullptr,
ggml_tensor* denoise_mask = nullptr,
const sd_easycache_params_t* easycache_params = nullptr) {
if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0.
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -2868,7 +2993,10 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
id_cond,
ref_latents,
increase_ref_index,
denoise_mask);
denoise_mask,
nullptr,
1.0f,
easycache_params);
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@@ -3185,7 +3313,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
ref_latents,
sd_img_gen_params->increase_ref_index,
concat_latent,
denoise_mask);
denoise_mask,
&sd_img_gen_params->easycache);
size_t t2 = ggml_time_ms();
@@ -3506,7 +3635,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
false,
denoise_mask,
vace_context,
sd_vid_gen_params->vace_strength);
sd_vid_gen_params->vace_strength,
&sd_vid_gen_params->easycache);
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@@ -3542,7 +3672,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
false,
denoise_mask,
vace_context,
sd_vid_gen_params->vace_strength);
sd_vid_gen_params->vace_strength,
&sd_vid_gen_params->easycache);
int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);