mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
feat: added prediction argument (#334)
This commit is contained in:
@@ -700,64 +700,102 @@ public:
|
||||
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
|
||||
}
|
||||
|
||||
// check is_using_v_parameterization_for_sd2
|
||||
if (sd_version_is_sd2(version)) {
|
||||
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
|
||||
// CosXL models
|
||||
// TODO: get sigma_min and sigma_max values from file
|
||||
is_using_edm_v_parameterization = true;
|
||||
}
|
||||
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (version == VERSION_SVD) {
|
||||
// TODO: V_PREDICTION_EDM
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
|
||||
if (sd_version_is_sd3(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
LOG_INFO("running in Flux FLOW mode");
|
||||
float shift = 1.0f; // TODO: validate
|
||||
for (auto pair : model_loader.tensor_storages_types) {
|
||||
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
|
||||
shift = 1.15f;
|
||||
if (sd_ctx_params->prediction != DEFAULT_PRED) {
|
||||
switch (sd_ctx_params->prediction) {
|
||||
case EPS_PRED:
|
||||
LOG_INFO("running in eps-prediction mode");
|
||||
break;
|
||||
case V_PRED:
|
||||
LOG_INFO("running in v-prediction mode");
|
||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||
break;
|
||||
case EDM_V_PRED:
|
||||
LOG_INFO("running in v-prediction EDM mode");
|
||||
denoiser = std::make_shared<EDMVDenoiser>();
|
||||
break;
|
||||
case SD3_FLOW_PRED: {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
break;
|
||||
}
|
||||
case FLUX_FLOW_PRED: {
|
||||
LOG_INFO("running in Flux FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_wan(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 5.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_qwen_image(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (is_using_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction mode");
|
||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||
} else if (is_using_edm_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction EDM mode");
|
||||
denoiser = std::make_shared<EDMVDenoiser>();
|
||||
} else {
|
||||
LOG_INFO("running in eps-prediction mode");
|
||||
if (sd_version_is_sd2(version)) {
|
||||
// check is_using_v_parameterization_for_sd2
|
||||
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (sd_version_is_sdxl(version)) {
|
||||
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
|
||||
// CosXL models
|
||||
// TODO: get sigma_min and sigma_max values from file
|
||||
is_using_edm_v_parameterization = true;
|
||||
}
|
||||
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
} else if (version == VERSION_SVD) {
|
||||
// TODO: V_PREDICTION_EDM
|
||||
is_using_v_parameterization = true;
|
||||
}
|
||||
|
||||
if (sd_version_is_sd3(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
LOG_INFO("running in Flux FLOW mode");
|
||||
float shift = 1.0f; // TODO: validate
|
||||
for (auto pair : model_loader.tensor_storages_types) {
|
||||
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
|
||||
shift = 1.15f;
|
||||
break;
|
||||
}
|
||||
}
|
||||
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_wan(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 5.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_qwen_image(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (is_using_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction mode");
|
||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||
} else if (is_using_edm_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction EDM mode");
|
||||
denoiser = std::make_shared<EDMVDenoiser>();
|
||||
} else {
|
||||
LOG_INFO("running in eps-prediction mode");
|
||||
}
|
||||
}
|
||||
|
||||
auto comp_vis_denoiser = std::dynamic_pointer_cast<CompVisDenoiser>(denoiser);
|
||||
@@ -1742,6 +1780,31 @@ enum scheduler_t str_to_schedule(const char* str) {
|
||||
return SCHEDULE_COUNT;
|
||||
}
|
||||
|
||||
const char* prediction_to_str[] = {
|
||||
"default",
|
||||
"eps",
|
||||
"v",
|
||||
"edm_v",
|
||||
"sd3_flow",
|
||||
"flux_flow",
|
||||
};
|
||||
|
||||
const char* sd_prediction_name(enum prediction_t prediction) {
|
||||
if (prediction < PREDICTION_COUNT) {
|
||||
return prediction_to_str[prediction];
|
||||
}
|
||||
return NONE_STR;
|
||||
}
|
||||
|
||||
enum prediction_t str_to_prediction(const char* str) {
|
||||
for (int i = 0; i < PREDICTION_COUNT; i++) {
|
||||
if (!strcmp(str, prediction_to_str[i])) {
|
||||
return (enum prediction_t)i;
|
||||
}
|
||||
}
|
||||
return PREDICTION_COUNT;
|
||||
}
|
||||
|
||||
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
||||
*sd_ctx_params = {};
|
||||
sd_ctx_params->vae_decode_only = true;
|
||||
@@ -1749,6 +1812,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
||||
sd_ctx_params->n_threads = get_num_physical_cores();
|
||||
sd_ctx_params->wtype = SD_TYPE_COUNT;
|
||||
sd_ctx_params->rng_type = CUDA_RNG;
|
||||
sd_ctx_params->prediction = DEFAULT_PRED;
|
||||
sd_ctx_params->offload_params_to_cpu = false;
|
||||
sd_ctx_params->keep_clip_on_cpu = false;
|
||||
sd_ctx_params->keep_control_net_on_cpu = false;
|
||||
@@ -1788,6 +1852,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
"n_threads: %d\n"
|
||||
"wtype: %s\n"
|
||||
"rng_type: %s\n"
|
||||
"prediction: %s\n"
|
||||
"offload_params_to_cpu: %s\n"
|
||||
"keep_clip_on_cpu: %s\n"
|
||||
"keep_control_net_on_cpu: %s\n"
|
||||
@@ -1816,6 +1881,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
sd_ctx_params->n_threads,
|
||||
sd_type_name(sd_ctx_params->wtype),
|
||||
sd_rng_type_name(sd_ctx_params->rng_type),
|
||||
sd_prediction_name(sd_ctx_params->prediction),
|
||||
BOOL_STR(sd_ctx_params->offload_params_to_cpu),
|
||||
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
|
||||
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
|
||||
|
||||
Reference in New Issue
Block a user