mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
feat: use Euler sampling by default for SD3 and Flux (#753)
Thank you for your contribution.
This commit is contained in:
@@ -43,7 +43,7 @@ const char* model_version_to_str[] = {
|
||||
};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
"Euler A",
|
||||
"default",
|
||||
"Euler",
|
||||
"Heun",
|
||||
"DPM2",
|
||||
@@ -55,6 +55,7 @@ const char* sampling_methods_str[] = {
|
||||
"LCM",
|
||||
"DDIM \"trailing\"",
|
||||
"TCD",
|
||||
"Euler A",
|
||||
};
|
||||
|
||||
/*================================================== Helper Functions ================================================*/
|
||||
@@ -1500,7 +1501,7 @@ enum rng_type_t str_to_rng_type(const char* str) {
|
||||
}
|
||||
|
||||
const char* sample_method_to_str[] = {
|
||||
"euler_a",
|
||||
"default",
|
||||
"euler",
|
||||
"heun",
|
||||
"dpm2",
|
||||
@@ -1512,6 +1513,7 @@ const char* sample_method_to_str[] = {
|
||||
"lcm",
|
||||
"ddim_trailing",
|
||||
"tcd",
|
||||
"euler_a",
|
||||
};
|
||||
|
||||
const char* sd_sample_method_name(enum sample_method_t sample_method) {
|
||||
@@ -1650,7 +1652,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
|
||||
sample_params->guidance.slg.layer_end = 0.2f;
|
||||
sample_params->guidance.slg.scale = 0.f;
|
||||
sample_params->scheduler = DEFAULT;
|
||||
sample_params->sample_method = EULER_A;
|
||||
sample_params->sample_method = SAMPLE_METHOD_DEFAULT;
|
||||
sample_params->sample_steps = 20;
|
||||
}
|
||||
|
||||
@@ -1792,6 +1794,17 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
|
||||
free(sd_ctx);
|
||||
}
|
||||
|
||||
enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
|
||||
if (sd_ctx != NULL && sd_ctx->sd != NULL) {
|
||||
SDVersion version = sd_ctx->sd->version;
|
||||
if (sd_version_is_dit(version))
|
||||
return EULER;
|
||||
else
|
||||
return EULER_A;
|
||||
}
|
||||
return SAMPLE_METHOD_COUNT;
|
||||
}
|
||||
|
||||
sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
struct ggml_context* work_ctx,
|
||||
ggml_tensor* init_latent,
|
||||
@@ -2356,6 +2369,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
||||
}
|
||||
|
||||
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
|
||||
if (sample_method == SAMPLE_METHOD_DEFAULT) {
|
||||
sample_method = sd_get_default_sample_method(sd_ctx);
|
||||
}
|
||||
|
||||
sd_image_t* result_images = generate_image_internal(sd_ctx,
|
||||
work_ctx,
|
||||
init_latent,
|
||||
@@ -2366,7 +2384,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
sd_img_gen_params->sample_params.eta,
|
||||
width,
|
||||
height,
|
||||
sd_img_gen_params->sample_params.sample_method,
|
||||
sample_method,
|
||||
sigmas,
|
||||
seed,
|
||||
sd_img_gen_params->batch_count,
|
||||
|
||||
Reference in New Issue
Block a user