mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
feat: add easycache support (#940)
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "control.hpp"
|
||||
#include "denoiser.hpp"
|
||||
#include "diffusion_model.hpp"
|
||||
#include "easycache.hpp"
|
||||
#include "esrgan.hpp"
|
||||
#include "lora.hpp"
|
||||
#include "pmid.hpp"
|
||||
@@ -1481,11 +1482,12 @@ public:
|
||||
const std::vector<float>& sigmas,
|
||||
int start_merge_step,
|
||||
SDCondition id_cond,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false,
|
||||
ggml_tensor* denoise_mask = nullptr,
|
||||
ggml_tensor* vace_context = nullptr,
|
||||
float vace_strength = 1.f) {
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false,
|
||||
ggml_tensor* denoise_mask = nullptr,
|
||||
ggml_tensor* vace_context = nullptr,
|
||||
float vace_strength = 1.f,
|
||||
const sd_easycache_params_t* easycache_params = nullptr) {
|
||||
if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) {
|
||||
LOG_WARN("timestep shifting is only supported for SDXL models!");
|
||||
shifted_timestep = 0;
|
||||
@@ -1501,6 +1503,42 @@ public:
|
||||
img_cfg_scale = cfg_scale;
|
||||
}
|
||||
|
||||
EasyCacheState easycache_state;
|
||||
bool easycache_enabled = false;
|
||||
if (easycache_params != nullptr && easycache_params->enabled) {
|
||||
bool easycache_supported = sd_version_is_dit(version);
|
||||
if (!easycache_supported) {
|
||||
LOG_WARN("EasyCache requested but not supported for this model type");
|
||||
} else {
|
||||
EasyCacheConfig easycache_config;
|
||||
easycache_config.enabled = true;
|
||||
easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold);
|
||||
easycache_config.start_percent = easycache_params->start_percent;
|
||||
easycache_config.end_percent = easycache_params->end_percent;
|
||||
bool percent_valid = easycache_config.start_percent >= 0.0f &&
|
||||
easycache_config.start_percent < 1.0f &&
|
||||
easycache_config.end_percent > 0.0f &&
|
||||
easycache_config.end_percent <= 1.0f &&
|
||||
easycache_config.start_percent < easycache_config.end_percent;
|
||||
if (!percent_valid) {
|
||||
LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)",
|
||||
easycache_config.start_percent,
|
||||
easycache_config.end_percent);
|
||||
} else {
|
||||
easycache_state.init(easycache_config, denoiser.get());
|
||||
if (easycache_state.enabled()) {
|
||||
easycache_enabled = true;
|
||||
LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
|
||||
easycache_config.reuse_threshold,
|
||||
easycache_config.start_percent,
|
||||
easycache_config.end_percent);
|
||||
} else {
|
||||
LOG_WARN("EasyCache requested but could not be initialized for this run");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t steps = sigmas.size() - 1;
|
||||
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
|
||||
copy_ggml_tensor(x, init_latent);
|
||||
@@ -1571,6 +1609,38 @@ public:
|
||||
pretty_progress(0, (int)steps, 0);
|
||||
}
|
||||
|
||||
DiffusionParams diffusion_params;
|
||||
|
||||
const bool easycache_step_active = easycache_enabled && step > 0;
|
||||
int easycache_step_index = easycache_step_active ? (step - 1) : -1;
|
||||
if (easycache_step_active) {
|
||||
easycache_state.begin_step(easycache_step_index, sigma);
|
||||
}
|
||||
|
||||
auto easycache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
|
||||
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return easycache_state.before_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor,
|
||||
sigma,
|
||||
easycache_step_index);
|
||||
};
|
||||
|
||||
auto easycache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
|
||||
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
easycache_state.after_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor);
|
||||
};
|
||||
|
||||
auto easycache_step_is_skipped = [&]() {
|
||||
return easycache_step_active && easycache_state.is_step_skipped();
|
||||
};
|
||||
|
||||
std::vector<float> scaling = denoiser->get_scalings(sigma);
|
||||
GGML_ASSERT(scaling.size() == 3);
|
||||
float c_skip = scaling[0];
|
||||
@@ -1616,7 +1686,6 @@ public:
|
||||
// GGML_ASSERT(0);
|
||||
}
|
||||
|
||||
DiffusionParams diffusion_params;
|
||||
diffusion_params.x = noised_input;
|
||||
diffusion_params.timesteps = timesteps;
|
||||
diffusion_params.guidance = guidance_tensor;
|
||||
@@ -1627,37 +1696,50 @@ public:
|
||||
diffusion_params.vace_context = vace_context;
|
||||
diffusion_params.vace_strength = vace_strength;
|
||||
|
||||
const SDCondition* active_condition = nullptr;
|
||||
struct ggml_tensor** active_output = &out_cond;
|
||||
if (start_merge_step == -1 || step <= start_merge_step) {
|
||||
// cond
|
||||
diffusion_params.context = cond.c_crossattn;
|
||||
diffusion_params.c_concat = cond.c_concat;
|
||||
diffusion_params.y = cond.c_vector;
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_cond);
|
||||
active_condition = &cond;
|
||||
} else {
|
||||
diffusion_params.context = id_cond.c_crossattn;
|
||||
diffusion_params.c_concat = cond.c_concat;
|
||||
diffusion_params.y = id_cond.c_vector;
|
||||
active_condition = &id_cond;
|
||||
}
|
||||
|
||||
bool skip_model = easycache_before_condition(active_condition, *active_output);
|
||||
if (!skip_model) {
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_cond);
|
||||
active_output);
|
||||
easycache_after_condition(active_condition, *active_output);
|
||||
}
|
||||
|
||||
bool current_step_skipped = easycache_step_is_skipped();
|
||||
|
||||
float* negative_data = nullptr;
|
||||
if (has_unconditioned) {
|
||||
// uncond
|
||||
if (control_hint != nullptr && control_net != nullptr) {
|
||||
if (!current_step_skipped && control_hint != nullptr && control_net != nullptr) {
|
||||
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
|
||||
controls = control_net->controls;
|
||||
}
|
||||
current_step_skipped = easycache_step_is_skipped();
|
||||
diffusion_params.controls = controls;
|
||||
diffusion_params.context = uncond.c_crossattn;
|
||||
diffusion_params.c_concat = uncond.c_concat;
|
||||
diffusion_params.y = uncond.c_vector;
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_uncond);
|
||||
bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
|
||||
if (!skip_uncond) {
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_uncond);
|
||||
easycache_after_condition(&uncond, out_uncond);
|
||||
}
|
||||
negative_data = (float*)out_uncond->data;
|
||||
}
|
||||
|
||||
@@ -1666,25 +1748,31 @@ public:
|
||||
diffusion_params.context = img_cond.c_crossattn;
|
||||
diffusion_params.c_concat = img_cond.c_concat;
|
||||
diffusion_params.y = img_cond.c_vector;
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_img_cond);
|
||||
bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
|
||||
if (!skip_img_cond) {
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_img_cond);
|
||||
easycache_after_condition(&img_cond, out_img_cond);
|
||||
}
|
||||
img_cond_data = (float*)out_img_cond->data;
|
||||
}
|
||||
|
||||
int step_count = sigmas.size();
|
||||
bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count);
|
||||
float* skip_layer_data = nullptr;
|
||||
float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
|
||||
if (is_skiplayer_step) {
|
||||
LOG_DEBUG("Skipping layers at step %d\n", step);
|
||||
// skip layer (same as conditionned)
|
||||
diffusion_params.context = cond.c_crossattn;
|
||||
diffusion_params.c_concat = cond.c_concat;
|
||||
diffusion_params.y = cond.c_vector;
|
||||
diffusion_params.skip_layers = skip_layers;
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_skip);
|
||||
if (!easycache_step_is_skipped()) {
|
||||
// skip layer (same as conditioned)
|
||||
diffusion_params.context = cond.c_crossattn;
|
||||
diffusion_params.c_concat = cond.c_concat;
|
||||
diffusion_params.y = cond.c_vector;
|
||||
diffusion_params.skip_layers = skip_layers;
|
||||
work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_skip);
|
||||
}
|
||||
skip_layer_data = (float*)out_skip->data;
|
||||
}
|
||||
float* vec_denoised = (float*)denoised->data;
|
||||
@@ -1748,6 +1836,26 @@ public:
|
||||
|
||||
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
|
||||
|
||||
if (easycache_enabled) {
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
if (easycache_state.total_steps_skipped > 0 && total_steps > 0) {
|
||||
if (easycache_state.total_steps_skipped < static_cast<int>(total_steps)) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - easycache_state.total_steps_skipped);
|
||||
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
|
||||
easycache_state.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup);
|
||||
} else {
|
||||
LOG_INFO("EasyCache skipped %d/%zu steps",
|
||||
easycache_state.total_steps_skipped,
|
||||
total_steps);
|
||||
}
|
||||
} else if (total_steps > 0) {
|
||||
LOG_INFO("EasyCache completed without skipping steps");
|
||||
}
|
||||
}
|
||||
|
||||
if (inverse_noise_scaling) {
|
||||
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
||||
}
|
||||
@@ -2294,6 +2402,14 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
|
||||
return LORA_APPLY_MODE_COUNT;
|
||||
}
|
||||
|
||||
void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
|
||||
*easycache_params = {};
|
||||
easycache_params->enabled = false;
|
||||
easycache_params->reuse_threshold = 0.2f;
|
||||
easycache_params->start_percent = 0.15f;
|
||||
easycache_params->end_percent = 0.95f;
|
||||
}
|
||||
|
||||
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
||||
*sd_ctx_params = {};
|
||||
sd_ctx_params->vae_decode_only = true;
|
||||
@@ -2452,6 +2568,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
||||
sd_img_gen_params->control_strength = 0.9f;
|
||||
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
||||
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_easycache_params_init(&sd_img_gen_params->easycache);
|
||||
}
|
||||
|
||||
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
@@ -2495,6 +2612,12 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
sd_img_gen_params->pm_params.id_images_count,
|
||||
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
|
||||
snprintf(buf + strlen(buf), 4096 - strlen(buf),
|
||||
"easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
|
||||
sd_img_gen_params->easycache.enabled ? "enabled" : "disabled",
|
||||
sd_img_gen_params->easycache.reuse_threshold,
|
||||
sd_img_gen_params->easycache.start_percent,
|
||||
sd_img_gen_params->easycache.end_percent);
|
||||
free(sample_params_str);
|
||||
return buf;
|
||||
}
|
||||
@@ -2511,6 +2634,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||
sd_vid_gen_params->video_frames = 6;
|
||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||
sd_vid_gen_params->vace_strength = 1.f;
|
||||
sd_easycache_params_init(&sd_vid_gen_params->easycache);
|
||||
}
|
||||
|
||||
struct sd_ctx_t {
|
||||
@@ -2578,8 +2702,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
std::vector<sd_image_t*> ref_images,
|
||||
std::vector<ggml_tensor*> ref_latents,
|
||||
bool increase_ref_index,
|
||||
ggml_tensor* concat_latent = nullptr,
|
||||
ggml_tensor* denoise_mask = nullptr) {
|
||||
ggml_tensor* concat_latent = nullptr,
|
||||
ggml_tensor* denoise_mask = nullptr,
|
||||
const sd_easycache_params_t* easycache_params = nullptr) {
|
||||
if (seed < 0) {
|
||||
// Generally, when using the provided command line, the seed is always >0.
|
||||
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
|
||||
@@ -2868,7 +2993,10 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
id_cond,
|
||||
ref_latents,
|
||||
increase_ref_index,
|
||||
denoise_mask);
|
||||
denoise_mask,
|
||||
nullptr,
|
||||
1.0f,
|
||||
easycache_params);
|
||||
// print_ggml_tensor(x_0);
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
@@ -3185,7 +3313,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
ref_latents,
|
||||
sd_img_gen_params->increase_ref_index,
|
||||
concat_latent,
|
||||
denoise_mask);
|
||||
denoise_mask,
|
||||
&sd_img_gen_params->easycache);
|
||||
|
||||
size_t t2 = ggml_time_ms();
|
||||
|
||||
@@ -3506,7 +3635,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
false,
|
||||
denoise_mask,
|
||||
vace_context,
|
||||
sd_vid_gen_params->vace_strength);
|
||||
sd_vid_gen_params->vace_strength,
|
||||
&sd_vid_gen_params->easycache);
|
||||
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
@@ -3542,7 +3672,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
false,
|
||||
denoise_mask,
|
||||
vace_context,
|
||||
sd_vid_gen_params->vace_strength);
|
||||
sd_vid_gen_params->vace_strength,
|
||||
&sd_vid_gen_params->easycache);
|
||||
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
|
||||
Reference in New Issue
Block a user