mirror of
https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git
synced 2026-02-04 03:01:18 +01:00
refactor: simplify the logic of pm id image loading (#827)
This commit is contained in:
@@ -412,7 +412,7 @@ public:
|
||||
clip_vision->get_param_tensors(tensors);
|
||||
}
|
||||
} else { // SD1.x SD2.x SDXL
|
||||
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
@@ -510,7 +510,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
@@ -525,15 +525,15 @@ public:
|
||||
"pmid",
|
||||
version);
|
||||
}
|
||||
if (strlen(SAFE_STR(sd_ctx_params->stacked_id_embed_dir)) > 0) {
|
||||
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->stacked_id_embed_dir, "");
|
||||
if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) {
|
||||
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "");
|
||||
if (!pmid_lora->load_from_file(true)) {
|
||||
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->stacked_id_embed_dir);
|
||||
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path);
|
||||
return false;
|
||||
}
|
||||
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->stacked_id_embed_dir);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->stacked_id_embed_dir, "pmid.")) {
|
||||
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->stacked_id_embed_dir);
|
||||
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) {
|
||||
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path);
|
||||
} else {
|
||||
stacked_id = true;
|
||||
}
|
||||
@@ -1644,7 +1644,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
"control_net_path: %s\n"
|
||||
"lora_model_dir: %s\n"
|
||||
"embedding_dir: %s\n"
|
||||
"stacked_id_embed_dir: %s\n"
|
||||
"photo_maker_path: %s\n"
|
||||
"vae_decode_only: %s\n"
|
||||
"vae_tiling: %s\n"
|
||||
"free_params_immediately: %s\n"
|
||||
@@ -1671,7 +1671,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
SAFE_STR(sd_ctx_params->control_net_path),
|
||||
SAFE_STR(sd_ctx_params->lora_model_dir),
|
||||
SAFE_STR(sd_ctx_params->embedding_dir),
|
||||
SAFE_STR(sd_ctx_params->stacked_id_embed_dir),
|
||||
SAFE_STR(sd_ctx_params->photo_maker_path),
|
||||
BOOL_STR(sd_ctx_params->vae_decode_only),
|
||||
BOOL_STR(sd_ctx_params->free_params_immediately),
|
||||
sd_ctx_params->n_threads,
|
||||
@@ -1747,8 +1747,8 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
||||
sd_img_gen_params->seed = -1;
|
||||
sd_img_gen_params->batch_count = 1;
|
||||
sd_img_gen_params->control_strength = 0.9f;
|
||||
sd_img_gen_params->style_strength = 20.f;
|
||||
sd_img_gen_params->normalize_input = false;
|
||||
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
||||
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
}
|
||||
|
||||
@@ -1769,15 +1769,13 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
"sample_params: %s\n"
|
||||
"strength: %.2f\n"
|
||||
"seed: %" PRId64
|
||||
"VAE tiling:"
|
||||
"\n"
|
||||
"batch_count: %d\n"
|
||||
"ref_images_count: %d\n"
|
||||
"increase_ref_index: %s\n"
|
||||
"control_strength: %.2f\n"
|
||||
"style_strength: %.2f\n"
|
||||
"normalize_input: %s\n"
|
||||
"input_id_images_path: %s\n",
|
||||
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
|
||||
"VAE tiling: %s\n",
|
||||
SAFE_STR(sd_img_gen_params->prompt),
|
||||
SAFE_STR(sd_img_gen_params->negative_prompt),
|
||||
sd_img_gen_params->clip_skip,
|
||||
@@ -1786,14 +1784,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
SAFE_STR(sample_params_str),
|
||||
sd_img_gen_params->strength,
|
||||
sd_img_gen_params->seed,
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
|
||||
sd_img_gen_params->batch_count,
|
||||
sd_img_gen_params->ref_images_count,
|
||||
BOOL_STR(sd_img_gen_params->increase_ref_index),
|
||||
sd_img_gen_params->control_strength,
|
||||
sd_img_gen_params->style_strength,
|
||||
BOOL_STR(sd_img_gen_params->normalize_input),
|
||||
SAFE_STR(sd_img_gen_params->input_id_images_path));
|
||||
sd_img_gen_params->pm_params.style_strength,
|
||||
sd_img_gen_params->pm_params.id_images_count,
|
||||
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
|
||||
free(sample_params_str);
|
||||
return buf;
|
||||
}
|
||||
@@ -1872,9 +1871,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
int batch_count,
|
||||
sd_image_t control_image,
|
||||
float control_strength,
|
||||
float style_ratio,
|
||||
bool normalize_input,
|
||||
std::string input_id_images_path,
|
||||
sd_pm_params_t pm_params,
|
||||
std::vector<ggml_tensor*> ref_latents,
|
||||
bool increase_ref_index,
|
||||
ggml_tensor* concat_latent = NULL,
|
||||
@@ -1915,67 +1913,46 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
}
|
||||
}
|
||||
// preprocess input id images
|
||||
std::vector<sd_image_t*> input_id_images;
|
||||
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
|
||||
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
|
||||
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
|
||||
for (std::string img_file : img_files) {
|
||||
int c = 0;
|
||||
int width, height;
|
||||
if (ends_with(img_file, "safetensors")) {
|
||||
continue;
|
||||
}
|
||||
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
|
||||
if (input_image_buffer == NULL) {
|
||||
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
|
||||
continue;
|
||||
} else {
|
||||
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
|
||||
}
|
||||
sd_image_t* input_image = NULL;
|
||||
input_image = new sd_image_t{(uint32_t)width,
|
||||
(uint32_t)height,
|
||||
3,
|
||||
input_image_buffer};
|
||||
input_image = preprocess_id_image(input_image);
|
||||
if (input_image == NULL) {
|
||||
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
|
||||
continue;
|
||||
}
|
||||
input_id_images.push_back(input_image);
|
||||
if (pm_params.id_images_count > 0) {
|
||||
int clip_image_size = 224;
|
||||
sd_ctx->sd->pmid_model->style_strength = pm_params.style_strength;
|
||||
|
||||
init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count);
|
||||
|
||||
std::vector<sd_image_f32_t> processed_id_images;
|
||||
for (int i = 0; i < pm_params.id_images_count; i++) {
|
||||
sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
|
||||
sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size);
|
||||
free(id_image.data);
|
||||
id_image.data = NULL;
|
||||
processed_id_images.push_back(processed_id_image);
|
||||
}
|
||||
}
|
||||
if (input_id_images.size() > 0) {
|
||||
sd_ctx->sd->pmid_model->style_strength = style_ratio;
|
||||
int32_t w = input_id_images[0]->width;
|
||||
int32_t h = input_id_images[0]->height;
|
||||
int32_t channels = input_id_images[0]->channel;
|
||||
int32_t num_input_images = (int32_t)input_id_images.size();
|
||||
init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, w, h, channels, num_input_images);
|
||||
// TODO: move these to somewhere else and be user settable
|
||||
float mean[] = {0.48145466f, 0.4578275f, 0.40821073f};
|
||||
float std[] = {0.26862954f, 0.26130258f, 0.27577711f};
|
||||
for (int i = 0; i < num_input_images; i++) {
|
||||
sd_image_t* init_image = input_id_images[i];
|
||||
if (normalize_input)
|
||||
sd_mul_images_to_tensor(init_image->data, init_img, i, mean, std);
|
||||
else
|
||||
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
|
||||
|
||||
ggml_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2);
|
||||
ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
|
||||
});
|
||||
|
||||
for (auto& image : processed_id_images) {
|
||||
free(image.data);
|
||||
image.data = NULL;
|
||||
}
|
||||
processed_id_images.clear();
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
|
||||
sd_ctx->sd->n_threads, prompt,
|
||||
clip_skip,
|
||||
width,
|
||||
height,
|
||||
num_input_images,
|
||||
pm_params.id_images_count,
|
||||
sd_ctx->sd->diffusion_model->get_adm_in_channels());
|
||||
id_cond = std::get<0>(cond_tup);
|
||||
class_tokens_mask = std::get<1>(cond_tup); //
|
||||
struct ggml_tensor* id_embeds = NULL;
|
||||
if (pmv2) {
|
||||
// id_embeds = sd_ctx->sd->pmid_id_embeds->get();
|
||||
id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin"));
|
||||
if (pmv2 && pm_params.id_embed_path != nullptr) {
|
||||
id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
|
||||
// print_ggml_tensor(id_embeds, true, "id_embeds:");
|
||||
}
|
||||
id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
|
||||
@@ -1988,19 +1965,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt);
|
||||
// printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str());
|
||||
prompt = prompt_text_only; //
|
||||
// if (sample_steps < 50) {
|
||||
// LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps);
|
||||
// sample_steps = 50;
|
||||
// }
|
||||
if (sample_steps < 50) {
|
||||
LOG_WARN("It's recommended to use >= 50 steps for photo maker!");
|
||||
}
|
||||
} else {
|
||||
LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
|
||||
LOG_WARN("Turn off PhotoMaker");
|
||||
sd_ctx->sd->stacked_id = false;
|
||||
}
|
||||
for (sd_image_t* img : input_id_images) {
|
||||
free(img->data);
|
||||
}
|
||||
input_id_images.clear();
|
||||
}
|
||||
|
||||
// Get learned condition
|
||||
@@ -2248,7 +2220,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
}
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G
|
||||
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G
|
||||
params.mem_buffer = NULL;
|
||||
params.no_alloc = false;
|
||||
// LOG_DEBUG("mem_size %u ", params.mem_size);
|
||||
@@ -2430,9 +2402,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
sd_img_gen_params->batch_count,
|
||||
sd_img_gen_params->control_image,
|
||||
sd_img_gen_params->control_strength,
|
||||
sd_img_gen_params->style_strength,
|
||||
sd_img_gen_params->normalize_input,
|
||||
SAFE_STR(sd_img_gen_params->input_id_images_path),
|
||||
sd_img_gen_params->pm_params,
|
||||
ref_latents,
|
||||
sd_img_gen_params->increase_ref_index,
|
||||
concat_latent,
|
||||
|
||||
Reference in New Issue
Block a user