mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 11:11:19 +01:00
feat: add sgm_uniform scheduler, simple scheduler, and support for NitroFusion (#675)
* feat: Add timestep shift and two new schedulers
* update readme
* fix spaces
* format code
* simplify SGMUniformSchedule
* simplify shifted_timestep logic
* avoid conflict

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
@@ -747,6 +747,16 @@ public:
|
||||
denoiser->scheduler = std::make_shared<GITSSchedule>();
|
||||
denoiser->scheduler->version = version;
|
||||
break;
|
||||
case SGM_UNIFORM:
|
||||
LOG_INFO("Running with SGM Uniform schedule");
|
||||
denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
|
||||
denoiser->scheduler->version = version;
|
||||
break;
|
||||
case SIMPLE:
|
||||
LOG_INFO("Running with Simple schedule");
|
||||
denoiser->scheduler = std::make_shared<SimpleSchedule>();
|
||||
denoiser->scheduler->version = version;
|
||||
break;
|
||||
case SMOOTHSTEP:
|
||||
LOG_INFO("Running with SmoothStep scheduler");
|
||||
denoiser->scheduler = std::make_shared<SmoothStepSchedule>();
|
||||
@@ -1033,6 +1043,7 @@ public:
|
||||
float control_strength,
|
||||
sd_guidance_params_t guidance,
|
||||
float eta,
|
||||
int shifted_timestep,
|
||||
sample_method_t method,
|
||||
const std::vector<float>& sigmas,
|
||||
int start_merge_step,
|
||||
@@ -1042,6 +1053,10 @@ public:
|
||||
ggml_tensor* denoise_mask = NULL,
|
||||
ggml_tensor* vace_context = NULL,
|
||||
float vace_strength = 1.f) {
|
||||
if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) {
|
||||
LOG_WARN("timestep shifting is only supported for SDXL models!");
|
||||
shifted_timestep = 0;
|
||||
}
|
||||
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
|
||||
|
||||
float cfg_scale = guidance.txt_cfg;
|
||||
@@ -1102,7 +1117,17 @@ public:
|
||||
float c_in = scaling[2];
|
||||
|
||||
float t = denoiser->sigma_to_t(sigma);
|
||||
std::vector<float> timesteps_vec(1, t); // [N, ]
|
||||
std::vector<float> timesteps_vec;
|
||||
if (shifted_timestep > 0 && sd_version_is_sdxl(version)) {
|
||||
float shifted_t_float = t * (float(shifted_timestep) / float(TIMESTEPS));
|
||||
int64_t shifted_t = static_cast<int64_t>(roundf(shifted_t_float));
|
||||
shifted_t = std::max((int64_t)0, std::min((int64_t)(TIMESTEPS - 1), shifted_t));
|
||||
LOG_DEBUG("shifting timestep from %.2f to %" PRId64 " (sigma: %.4f)", t, shifted_t, sigma);
|
||||
timesteps_vec.assign(1, (float)shifted_t);
|
||||
} else {
|
||||
timesteps_vec.assign(1, t);
|
||||
}
|
||||
|
||||
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
|
||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||
std::vector<float> guidance_vec(1, guidance.distilled_guidance);
|
||||
@@ -1200,6 +1225,19 @@ public:
|
||||
float* vec_input = (float*)input->data;
|
||||
float* positive_data = (float*)out_cond->data;
|
||||
int ne_elements = (int)ggml_nelements(denoised);
|
||||
|
||||
if (shifted_timestep > 0 && sd_version_is_sdxl(version)) {
|
||||
int64_t shifted_t_idx = static_cast<int64_t>(roundf(timesteps_vec[0]));
|
||||
float shifted_sigma = denoiser->t_to_sigma((float)shifted_t_idx);
|
||||
std::vector<float> shifted_scaling = denoiser->get_scalings(shifted_sigma);
|
||||
float shifted_c_skip = shifted_scaling[0];
|
||||
float shifted_c_out = shifted_scaling[1];
|
||||
float shifted_c_in = shifted_scaling[2];
|
||||
|
||||
c_skip = shifted_c_skip * c_in / shifted_c_in;
|
||||
c_out = shifted_c_out;
|
||||
}
|
||||
|
||||
for (int i = 0; i < ne_elements; i++) {
|
||||
float latent_result = positive_data[i];
|
||||
if (has_unconditioned) {
|
||||
@@ -1222,6 +1260,7 @@ public:
|
||||
// denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
|
||||
vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
|
||||
}
|
||||
|
||||
int64_t t1 = ggml_time_us();
|
||||
if (step > 0) {
|
||||
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
|
||||
@@ -1588,6 +1627,8 @@ const char* schedule_to_str[] = {
|
||||
"exponential",
|
||||
"ays",
|
||||
"gits",
|
||||
"sgm_uniform",
|
||||
"simple",
|
||||
"smoothstep",
|
||||
};
|
||||
|
||||
@@ -1720,7 +1761,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
|
||||
"scheduler: %s, "
|
||||
"sample_method: %s, "
|
||||
"sample_steps: %d, "
|
||||
"eta: %.2f)",
|
||||
"eta: %.2f, "
|
||||
"shifted_timestep: %d)",
|
||||
sample_params->guidance.txt_cfg,
|
||||
sample_params->guidance.img_cfg,
|
||||
sample_params->guidance.distilled_guidance,
|
||||
@@ -1731,7 +1773,8 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
|
||||
sd_schedule_name(sample_params->scheduler),
|
||||
sd_sample_method_name(sample_params->sample_method),
|
||||
sample_params->sample_steps,
|
||||
sample_params->eta);
|
||||
sample_params->eta,
|
||||
sample_params->shifted_timestep);
|
||||
|
||||
return buf;
|
||||
}
|
||||
@@ -1863,6 +1906,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
int clip_skip,
|
||||
sd_guidance_params_t guidance,
|
||||
float eta,
|
||||
int shifted_timestep,
|
||||
int width,
|
||||
int height,
|
||||
enum sample_method_t sample_method,
|
||||
@@ -2101,6 +2145,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
control_strength,
|
||||
guidance,
|
||||
eta,
|
||||
shifted_timestep,
|
||||
sample_method,
|
||||
sigmas,
|
||||
start_merge_step,
|
||||
@@ -2394,6 +2439,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
sd_img_gen_params->clip_skip,
|
||||
sd_img_gen_params->sample_params.guidance,
|
||||
sd_img_gen_params->sample_params.eta,
|
||||
sd_img_gen_params->sample_params.shifted_timestep,
|
||||
width,
|
||||
height,
|
||||
sample_method,
|
||||
@@ -2734,6 +2780,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
0,
|
||||
sd_vid_gen_params->high_noise_sample_params.guidance,
|
||||
sd_vid_gen_params->high_noise_sample_params.eta,
|
||||
sd_vid_gen_params->high_noise_sample_params.shifted_timestep,
|
||||
sd_vid_gen_params->high_noise_sample_params.sample_method,
|
||||
high_noise_sigmas,
|
||||
-1,
|
||||
@@ -2769,6 +2816,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
0,
|
||||
sd_vid_gen_params->sample_params.guidance,
|
||||
sd_vid_gen_params->sample_params.eta,
|
||||
sd_vid_gen_params->sample_params.shifted_timestep,
|
||||
sd_vid_gen_params->sample_params.sample_method,
|
||||
sigmas,
|
||||
-1,
|
||||
|
||||
Reference in New Issue
Block a user