feat: use Euler sampling by default for SD3 and Flux (#753)

Thank you for your contribution.
This commit is contained in:
Wagner Bruna
2025-09-14 01:34:41 -03:00
committed by GitHub
parent b54bec3f18
commit c607fc3ed4
4 changed files with 33 additions and 9 deletions

View File

@@ -43,7 +43,7 @@ const char* model_version_to_str[] = {
};
const char* sampling_methods_str[] = {
"Euler A",
"default",
"Euler",
"Heun",
"DPM2",
@@ -55,6 +55,7 @@ const char* sampling_methods_str[] = {
"LCM",
"DDIM \"trailing\"",
"TCD",
"Euler A",
};
/*================================================== Helper Functions ================================================*/
@@ -1500,7 +1501,7 @@ enum rng_type_t str_to_rng_type(const char* str) {
}
const char* sample_method_to_str[] = {
"euler_a",
"default",
"euler",
"heun",
"dpm2",
@@ -1512,6 +1513,7 @@ const char* sample_method_to_str[] = {
"lcm",
"ddim_trailing",
"tcd",
"euler_a",
};
const char* sd_sample_method_name(enum sample_method_t sample_method) {
@@ -1650,7 +1652,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
sample_params->guidance.slg.layer_end = 0.2f;
sample_params->guidance.slg.scale = 0.f;
sample_params->scheduler = DEFAULT;
sample_params->sample_method = EULER_A;
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
sample_params->sample_steps = 20;
}
@@ -1792,6 +1794,17 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
if (sd_ctx != NULL && sd_ctx->sd != NULL) {
SDVersion version = sd_ctx->sd->version;
if (sd_version_is_dit(version))
return EULER;
else
return EULER_A;
}
return SAMPLE_METHOD_COUNT;
}
sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
struct ggml_context* work_ctx,
ggml_tensor* init_latent,
@@ -2356,6 +2369,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
}
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_DEFAULT) {
sample_method = sd_get_default_sample_method(sd_ctx);
}
sd_image_t* result_images = generate_image_internal(sd_ctx,
work_ctx,
init_latent,
@@ -2366,7 +2384,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->sample_params.eta,
width,
height,
sd_img_gen_params->sample_params.sample_method,
sample_method,
sigmas,
seed,
sd_img_gen_params->batch_count,