mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
feat: support independent sampler rng (#978)
This commit is contained in:
@@ -99,10 +99,11 @@ public:
|
||||
bool vae_decode_only = false;
|
||||
bool free_params_immediately = false;
|
||||
|
||||
std::shared_ptr<RNG> rng = std::make_shared<STDDefaultRNG>();
|
||||
int n_threads = -1;
|
||||
float scale_factor = 0.18215f;
|
||||
float shift_factor = 0.f;
|
||||
std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
|
||||
std::shared_ptr<RNG> sampler_rng = nullptr;
|
||||
int n_threads = -1;
|
||||
float scale_factor = 0.18215f;
|
||||
float shift_factor = 0.f;
|
||||
|
||||
std::shared_ptr<Conditioner> cond_stage_model;
|
||||
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
|
||||
@@ -188,6 +189,16 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<RNG> get_rng(rng_type_t rng_type) {
|
||||
if (rng_type == STD_DEFAULT_RNG) {
|
||||
return std::make_shared<STDDefaultRNG>();
|
||||
} else if (rng_type == CPU_RNG) {
|
||||
return std::make_shared<MT19937RNG>();
|
||||
} else { // default: CUDA_RNG
|
||||
return std::make_shared<PhiloxRNG>();
|
||||
}
|
||||
}
|
||||
|
||||
bool init(const sd_ctx_params_t* sd_ctx_params) {
|
||||
n_threads = sd_ctx_params->n_threads;
|
||||
vae_decode_only = sd_ctx_params->vae_decode_only;
|
||||
@@ -197,12 +208,11 @@ public:
|
||||
use_tiny_autoencoder = taesd_path.size() > 0;
|
||||
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
|
||||
|
||||
if (sd_ctx_params->rng_type == STD_DEFAULT_RNG) {
|
||||
rng = std::make_shared<STDDefaultRNG>();
|
||||
} else if (sd_ctx_params->rng_type == CUDA_RNG) {
|
||||
rng = std::make_shared<PhiloxRNG>();
|
||||
} else if (sd_ctx_params->rng_type == CPU_RNG) {
|
||||
rng = std::make_shared<MT19937RNG>();
|
||||
rng = get_rng(sd_ctx_params->rng_type);
|
||||
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT) {
|
||||
sampler_rng = get_rng(sd_ctx_params->sampler_rng_type);
|
||||
} else {
|
||||
sampler_rng = rng;
|
||||
}
|
||||
|
||||
ggml_log_set(ggml_log_callback_default, nullptr);
|
||||
@@ -1736,7 +1746,7 @@ public:
|
||||
return denoised;
|
||||
};
|
||||
|
||||
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
|
||||
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
|
||||
|
||||
if (inverse_noise_scaling) {
|
||||
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
||||
@@ -2291,6 +2301,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
||||
sd_ctx_params->n_threads = get_num_physical_cores();
|
||||
sd_ctx_params->wtype = SD_TYPE_COUNT;
|
||||
sd_ctx_params->rng_type = CUDA_RNG;
|
||||
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
|
||||
sd_ctx_params->prediction = DEFAULT_PRED;
|
||||
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
|
||||
sd_ctx_params->offload_params_to_cpu = false;
|
||||
@@ -2332,6 +2343,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
"n_threads: %d\n"
|
||||
"wtype: %s\n"
|
||||
"rng_type: %s\n"
|
||||
"sampler_rng_type: %s\n"
|
||||
"prediction: %s\n"
|
||||
"offload_params_to_cpu: %s\n"
|
||||
"keep_clip_on_cpu: %s\n"
|
||||
@@ -2362,6 +2374,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
sd_ctx_params->n_threads,
|
||||
sd_type_name(sd_ctx_params->wtype),
|
||||
sd_rng_type_name(sd_ctx_params->rng_type),
|
||||
sd_rng_type_name(sd_ctx_params->sampler_rng_type),
|
||||
sd_prediction_name(sd_ctx_params->prediction),
|
||||
BOOL_STR(sd_ctx_params->offload_params_to_cpu),
|
||||
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
|
||||
@@ -2823,6 +2836,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed);
|
||||
|
||||
sd_ctx->sd->rng->manual_seed(cur_seed);
|
||||
sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
|
||||
struct ggml_tensor* x_t = init_latent;
|
||||
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||
ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
|
||||
@@ -2949,6 +2963,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
seed = rand();
|
||||
}
|
||||
sd_ctx->sd->rng->manual_seed(seed);
|
||||
sd_ctx->sd->sampler_rng->manual_seed(seed);
|
||||
|
||||
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
|
||||
|
||||
@@ -3240,6 +3255,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
}
|
||||
|
||||
sd_ctx->sd->rng->manual_seed(seed);
|
||||
sd_ctx->sd->sampler_rng->manual_seed(seed);
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user