refactor: simplify the logic of pm id image loading (#827)

leejet
2025-09-14 22:50:21 +08:00
committed by GitHub
parent 55c2e05d98
commit 0ebe6fe118
11 changed files with 181 additions and 514 deletions
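In short, this commit drops the directory-scan path for PhotoMaker ID images (stbi_load over input_id_images_path): callers now hand already-decoded images to generate_image_internal through a single sd_pm_params_t argument, and the context-level option stacked_id_embed_dir is renamed to photo_maker_path. A minimal caller-side sketch follows; the field names are taken from how the diff uses pm_params, while the overall flow and the image source are the caller's own, so treat it as an illustration rather than documented API usage.

    // Sketch only (not part of this commit): filling the new PhotoMaker params.
    sd_img_gen_params_t gen_params;
    sd_img_gen_params_init(&gen_params);

    sd_image_t id_images[2];                       // decoded RGB photos of the same person,
                                                   // filled by whatever image loader the caller uses
    gen_params.pm_params.id_images       = id_images;
    gen_params.pm_params.id_images_count = 2;
    gen_params.pm_params.id_embed_path   = "id_embeds.bin";  // only read for PhotoMaker v2; may stay NULL
    gen_params.pm_params.style_strength  = 20.f;             // same default as sd_img_gen_params_init()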

@@ -412,7 +412,7 @@ public:
clip_vision->get_param_tensors(tensors);
}
} else { // SD1.x SD2.x SDXL
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
@@ -510,7 +510,7 @@ public:
}
}
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
@@ -525,15 +525,15 @@ public:
"pmid",
version);
}
if (strlen(SAFE_STR(sd_ctx_params->stacked_id_embed_dir)) > 0) {
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->stacked_id_embed_dir, "");
if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) {
pmid_lora = std::make_shared<LoraModel>(backend, sd_ctx_params->photo_maker_path, "");
if (!pmid_lora->load_from_file(true)) {
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->stacked_id_embed_dir);
LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path);
return false;
}
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->stacked_id_embed_dir);
if (!model_loader.init_from_file(sd_ctx_params->stacked_id_embed_dir, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->stacked_id_embed_dir);
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path);
if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path);
} else {
stacked_id = true;
}
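
Note that the renamed photo_maker_path is consulted three times in the loader above: a "v2" substring in the path selects the v2 ID encoder, the file is first loaded as a PhotoMaker LoRA, and its "pmid."-prefixed tensors provide the stacked ID embedding. A context-setup sketch under stated assumptions (sd_ctx_params_init and the new_sd_ctx(&params) entry point are assumed by analogy with the generation-params API; file names are placeholders):

    // Sketch only: pointing the context at a PhotoMaker checkpoint via the renamed field.
    sd_ctx_params_t ctx_params;
    sd_ctx_params_init(&ctx_params);                            // assumed init helper
    ctx_params.model_path       = "sdxl_base.safetensors";      // placeholder
    ctx_params.photo_maker_path = "photomaker-v2.safetensors";  // "v2" in the path selects the v2 ID encoder
    sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);                 // assumed entry point
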
@@ -1644,7 +1644,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"control_net_path: %s\n"
"lora_model_dir: %s\n"
"embedding_dir: %s\n"
"stacked_id_embed_dir: %s\n"
"photo_maker_path: %s\n"
"vae_decode_only: %s\n"
"vae_tiling: %s\n"
"free_params_immediately: %s\n"
@@ -1671,7 +1671,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->control_net_path),
SAFE_STR(sd_ctx_params->lora_model_dir),
SAFE_STR(sd_ctx_params->embedding_dir),
SAFE_STR(sd_ctx_params->stacked_id_embed_dir),
SAFE_STR(sd_ctx_params->photo_maker_path),
BOOL_STR(sd_ctx_params->vae_decode_only),
BOOL_STR(sd_ctx_params->free_params_immediately),
sd_ctx_params->n_threads,
@@ -1747,8 +1747,8 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->seed = -1;
sd_img_gen_params->batch_count = 1;
sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->style_strength = 20.f;
sd_img_gen_params->normalize_input = false;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
}
@@ -1769,15 +1769,13 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
"sample_params: %s\n"
"strength: %.2f\n"
"seed: %" PRId64
"VAE tiling:"
"\n"
"batch_count: %d\n"
"ref_images_count: %d\n"
"increase_ref_index: %s\n"
"control_strength: %.2f\n"
"style_strength: %.2f\n"
"normalize_input: %s\n"
"input_id_images_path: %s\n",
"photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n"
"VAE tiling: %s\n",
SAFE_STR(sd_img_gen_params->prompt),
SAFE_STR(sd_img_gen_params->negative_prompt),
sd_img_gen_params->clip_skip,
@@ -1786,14 +1784,15 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
SAFE_STR(sample_params_str),
sd_img_gen_params->strength,
sd_img_gen_params->seed,
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled),
sd_img_gen_params->batch_count,
sd_img_gen_params->ref_images_count,
BOOL_STR(sd_img_gen_params->increase_ref_index),
sd_img_gen_params->control_strength,
sd_img_gen_params->style_strength,
BOOL_STR(sd_img_gen_params->normalize_input),
SAFE_STR(sd_img_gen_params->input_id_images_path));
sd_img_gen_params->pm_params.style_strength,
sd_img_gen_params->pm_params.id_images_count,
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
free(sample_params_str);
return buf;
}
@@ -1872,9 +1871,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int batch_count,
sd_image_t control_image,
float control_strength,
float style_ratio,
bool normalize_input,
std::string input_id_images_path,
sd_pm_params_t pm_params,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index,
ggml_tensor* concat_latent = NULL,
@@ -1915,67 +1913,46 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
}
}
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
int c = 0;
int width, height;
if (ends_with(img_file, "safetensors")) {
continue;
}
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
continue;
} else {
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
}
sd_image_t* input_image = NULL;
input_image = new sd_image_t{(uint32_t)width,
(uint32_t)height,
3,
input_image_buffer};
input_image = preprocess_id_image(input_image);
if (input_image == NULL) {
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
continue;
}
input_id_images.push_back(input_image);
if (pm_params.id_images_count > 0) {
int clip_image_size = 224;
sd_ctx->sd->pmid_model->style_strength = pm_params.style_strength;
init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count);
std::vector<sd_image_f32_t> processed_id_images;
for (int i = 0; i < pm_params.id_images_count; i++) {
sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size);
free(id_image.data);
id_image.data = NULL;
processed_id_images.push_back(processed_id_image);
}
}
if (input_id_images.size() > 0) {
sd_ctx->sd->pmid_model->style_strength = style_ratio;
int32_t w = input_id_images[0]->width;
int32_t h = input_id_images[0]->height;
int32_t channels = input_id_images[0]->channel;
int32_t num_input_images = (int32_t)input_id_images.size();
init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, w, h, channels, num_input_images);
// TODO: move these to somewhere else and be user settable
float mean[] = {0.48145466f, 0.4578275f, 0.40821073f};
float std[] = {0.26862954f, 0.26130258f, 0.27577711f};
for (int i = 0; i < num_input_images; i++) {
sd_image_t* init_image = input_id_images[i];
if (normalize_input)
sd_mul_images_to_tensor(init_image->data, init_img, i, mean, std);
else
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
ggml_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2);
ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
});
for (auto& image : processed_id_images) {
free(image.data);
image.data = NULL;
}
processed_id_images.clear();
int64_t t0 = ggml_time_ms();
auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx,
sd_ctx->sd->n_threads, prompt,
clip_skip,
width,
height,
num_input_images,
pm_params.id_images_count,
sd_ctx->sd->diffusion_model->get_adm_in_channels());
id_cond = std::get<0>(cond_tup);
class_tokens_mask = std::get<1>(cond_tup); //
struct ggml_tensor* id_embeds = NULL;
if (pmv2) {
// id_embeds = sd_ctx->sd->pmid_id_embeds->get();
id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin"));
if (pmv2 && pm_params.id_embed_path != nullptr) {
id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path);
// print_ggml_tensor(id_embeds, true, "id_embeds:");
}
id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask);
@@ -1988,19 +1965,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt);
// printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str());
prompt = prompt_text_only; //
// if (sample_steps < 50) {
// LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps);
// sample_steps = 50;
// }
if (sample_steps < 50) {
LOG_WARN("It's recommended to use >= 50 steps for photo maker!");
}
} else {
LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false;
}
for (sd_image_t* img : input_id_images) {
free(img->data);
}
input_id_images.clear();
}
// Get learned condition
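
For readability, the new-side ID-image path in the hunk above reduces to roughly the listing below. It is reassembled from the added lines as I read the diff (surrounding control flow and indentation are approximate), so treat it as a guide to the hunk rather than an exact excerpt: each caller-supplied image is converted to float, run through clip_preprocess at 224x224, and copied into one 4D tensor for the ID encoder.

    if (pm_params.id_images_count > 0) {
        int clip_image_size = 224;
        sd_ctx->sd->pmid_model->style_strength = pm_params.style_strength;
        init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count);

        // resize/normalize every caller-supplied ID image to 224x224 float
        std::vector<sd_image_f32_t> processed_id_images;
        for (int i = 0; i < pm_params.id_images_count; i++) {
            sd_image_f32_t id_image           = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]);
            sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size);
            free(id_image.data);
            id_image.data = NULL;
            processed_id_images.push_back(processed_id_image);
        }

        // copy every processed image into the 4D tensor (i3 indexes the image)
        ggml_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
            float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2);
            ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
        });
        for (auto& image : processed_id_images) {
            free(image.data);
            image.data = NULL;
        }
        processed_id_images.clear();

        // ... the trigger-word conditioning and the v2 id_embeds handling
        //     continue exactly as shown in the hunk above ...
    }
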
@@ -2248,7 +2220,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
}
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
@@ -2430,9 +2402,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_img_gen_params->batch_count,
sd_img_gen_params->control_image,
sd_img_gen_params->control_strength,
sd_img_gen_params->style_strength,
sd_img_gen_params->normalize_input,
SAFE_STR(sd_img_gen_params->input_id_images_path),
sd_img_gen_params->pm_params,
ref_latents,
sd_img_gen_params->increase_ref_index,
concat_latent,