feat: add support for Flux Controls and Flex.2 (#692)

Author: stduhpf
Date: 2025-10-10 18:06:57 +02:00
Committed by: GitHub
Parent: 35843c77ea
Commit: 11f436c483

7 changed files with 156 additions and 34 deletions

stable-diffusion.cpp

@@ -37,6 +37,8 @@ const char* model_version_to_str[] = {
"SD3.x",
"Flux",
"Flux Fill",
"Flux Control",
"Flex.2",
"Wan 2.x",
"Wan 2.2 I2V",
"Wan 2.2 TI2V",
@@ -102,7 +104,7 @@ public:
     std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
     std::shared_ptr<VAE> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
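
Strictly speaking, a default-constructed std::shared_ptr is already empty, so the new = NULL initializer documents intent rather than changing behavior; what matters is the explicit control_net != NULL guards added further down. A quick self-contained check:

    #include <cassert>
    #include <memory>

    struct ControlNet;  // opaque here; only pointer semantics matter

    int main() {
        std::shared_ptr<ControlNet> a;         // default-constructed: already null
        std::shared_ptr<ControlNet> b = NULL;  // equivalent, just self-documenting
        assert(!a && !b && a == b);
    }
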
@@ -320,6 +322,11 @@ public:
             scale_factor = 1.0f;
         }
+        if (sd_version_is_control(version)) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
         bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
         {
@@ -1147,7 +1154,7 @@ public:
         std::vector<struct ggml_tensor*> controls;
-        if (control_hint != NULL) {
+        if (control_hint != NULL && control_net != NULL) {
             control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
             controls = control_net->controls;
             // print_ggml_tensor(controls[12]);
@@ -1185,7 +1192,7 @@ public:
         float* negative_data = NULL;
         if (has_unconditioned) {
             // uncond
-            if (control_hint != NULL) {
+            if (control_hint != NULL && control_net != NULL) {
                 control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                 controls = control_net->controls;
             }
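
Both compute() call sites now use the same guard: a control hint is only fed through a ControlNet when one is actually loaded. With Flex.2 a hint can be present while no ControlNet exists at all, in which case the encoded hint travels through cond.c_concat instead (see the hunks around line 2121 below). A hypothetical condensation of the repeated condition:

    #include <memory>

    struct ggml_tensor;  // opaque stand-ins for this sketch
    struct ControlNet;

    // Hypothetical helper, not part of the diff: run the ControlNet only when
    // both a hint and a loaded ControlNet exist.
    static bool should_run_control_net(const ggml_tensor* control_hint,
                                       const std::shared_ptr<ControlNet>& control_net) {
        return control_hint != nullptr && control_net != nullptr;
    }
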
@@ -2070,10 +2077,24 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
     int W = width / 8;
     int H = height / 8;
     LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
+    struct ggml_tensor* control_latent = NULL;
+    if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+        } else {
+            control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+        }
+        ggml_tensor_scale(control_latent, control_strength);
+    }
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8; // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
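
For Flex.2, mask_channels = 1 + init_latent->ne[2] means the concat conditioning carries three blocks per latent pixel: the masked-image latent, a one-channel downsampled mask, and a control latent of the same depth. Index helpers make the layout explicit; these are illustrative only, and the depth of 16 is an assumption (typical of Flux-family VAEs), not something the diff states:

    #include <cstdint>

    // Illustrative layout helpers for the Flex.2 concat tensor, with
    // C = init_latent->ne[2]: channels [0, C) hold the masked-image latent,
    // channel C holds the mask, channels [C+1, 2C+1) hold the control latent.
    static inline int64_t flex2_mask_channel(int64_t C) { return C; }
    static inline int64_t flex2_control_channel(int64_t C, int64_t k) { return C + 1 + k; }

    int main() {
        const int64_t C = 16;  // assumed latent depth
        // total concat channels: mask_channels + C == (1 + C) + C == 2*C + 1 == 33
        return (flex2_mask_channel(C) == 16 && flex2_control_channel(C, 0) == 17) ? 0 : 1;
    }
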
@@ -2087,6 +2108,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                 for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                 }
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                    // 0x16,1x1,0x16
+                    ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                }
             } else {
                 ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                 for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
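
The terse c == init_latent->ne[2] works because the comparison yields a bool that implicitly converts to 1.0f or 0.0f when passed as the float argument, producing the 0x16,1x1,0x16 pattern noted in the comment: zeros for the image and control channels, a single 1 on the mask channel (everything masked). Spelled out, under the same assumed depth of 16:

    #include <cstdio>

    int main() {
        const long long C = 16;  // assumed depth; C plays the role of init_latent->ne[2]
        for (long long c = 0; c < 2 * C + 1; c++) {
            float v = (c == C) ? 1.0f : 0.0f;  // what the bool-to-float conversion yields
            if (v != 0.0f) printf("1.0 written only at the mask channel, c = %lld\n", c);
        }
    }
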
@@ -2095,7 +2121,28 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (concat_latent == NULL) {
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            for (int64_t x = 0; x < control_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < control_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latent->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latent, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
+                    }
+                }
+            }
+        } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
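
The destination index concat_latent->ne[2] - control_latent->ne[2] + c lands exactly on the control block of the layout sketched earlier: with C latent channels, the concat tensor has 2C + 1 channels, so the first control slot sits at (2C + 1) - C = C + 1. A quick arithmetic check under the same assumed C = 16:

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t C      = 16;         // assumed control/init latent depth
        const int64_t concat = 2 * C + 1;  // masked image + mask + control
        assert(concat - C + 0 == C + 1);        // first control channel (c == 0)
        assert(concat - C + (C - 1) == 2 * C);  // last control channel
    }
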
@@ -2105,10 +2152,20 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
         auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
         ggml_set_f32(empty_latent, 0);
         uncond.c_concat = empty_latent;
-        if (concat_latent == NULL) {
-            concat_latent = empty_latent;
-        }
-        cond.c_concat = ref_latents[0];
+        cond.c_concat = ref_latents[0];
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
+    } else if (sd_version_is_control(sd_ctx->sd->version)) {
+        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
+        ggml_set_f32(empty_latent, 0);
+        uncond.c_concat = empty_latent;
+        if (sd_ctx->sd->control_net == NULL) {
+            cond.c_concat = control_latent;
+        }
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
     }
     SDCondition img_cond;
     if (uncond.c_crossattn != NULL &&
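
The fallback order for the positive condition is: the model-specific source first (the first reference latent for edit models, the scaled control latent for control models without a loaded ControlNet), then the all-zero latent as the final default; the unconditional branch always gets zeros. A condensed paraphrase, not verbatim library code:

    #include <cstddef>

    struct ggml_tensor;  // opaque stand-in

    // Paraphrase of the branch above: preferred source if present, zeros otherwise.
    static ggml_tensor* pick_concat(ggml_tensor* preferred, ggml_tensor* zeros) {
        return preferred != NULL ? preferred : zeros;
    }
    // edit models   : cond.c_concat = pick_concat(ref_latents[0], empty_latent);
    // control models: cond.c_concat = pick_concat(control_net == NULL ? control_latent : NULL, empty_latent);
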
@@ -2291,6 +2348,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
     ggml_tensor* init_latent = NULL;
     ggml_tensor* init_moments = NULL;
+    ggml_tensor* concat_latent = NULL;
     ggml_tensor* denoise_mask = NULL;
     if (sd_img_gen_params->init_image.data) {
@@ -2310,19 +2368,35 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
             sd_image_to_tensor(sd_img_gen_params->mask_image, mask_img);
             sd_image_to_tensor(sd_img_gen_params->init_image, init_img);
+            if (!sd_ctx->sd->use_tiny_autoencoder) {
+                init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+                init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
+            } else {
+                init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+            }
             if (sd_version_is_inpaint(sd_ctx->sd->version)) {
                 int64_t mask_channels = 1;
                 if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
                     mask_channels = 8 * 8; // flatten the whole mask
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    mask_channels = 1 + init_latent->ne[2];
                 }
-                ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
-                sd_apply_mask(init_img, mask_img, masked_img);
                 ggml_tensor* masked_latent = NULL;
-                if (!sd_ctx->sd->use_tiny_autoencoder) {
-                    ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-                    masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
-                } else {
-                    masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-                }
+                if (sd_ctx->sd->version != VERSION_FLEX_2) {
+                    // most inpaint models mask before vae
+                    ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+                    sd_apply_mask(init_img, mask_img, masked_img);
+                    if (!sd_ctx->sd->use_tiny_autoencoder) {
+                        ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                        masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+                    } else {
+                        masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                    }
+                } else {
+                    // mask after vae
+                    masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
+                    sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
+                }
                 concat_latent = ggml_new_tensor_4d(work_ctx,
                                                    GGML_TYPE_F32,
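
The restructuring separates two masking orders: most inpaint models blank the RGB pixels first and then VAE-encode (the comment "mask before vae"), whereas Flex.2 masks the already-encoded latent directly ("mask after vae"). The trailing 0. in sd_apply_mask(init_latent, mask_img, masked_latent, 0.) reads as the fill value for masked positions; that is inferred from the call site, not stated by the diff. The two orders are genuinely different operations, since encoding a blanked pixel is not the same as blanking the encoded value:

    #include <cstdio>

    // Toy illustration, not library code: an affine stand-in for the VAE encoder
    // shows why mask-then-encode and encode-then-mask disagree on masked cells.
    static float enc(float x) { return 0.5f * x + 1.0f; }

    int main() {
        printf("pre-VAE  (blank pixel, then encode): %f\n", enc(0.0f));  // 1.0
        printf("post-VAE (encode, then blank cell) : %f\n", 0.0f);       // 0.0
    }
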
@@ -2348,12 +2422,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                                 ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                             }
                         }
-                    } else {
+                    } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
                         float m = ggml_tensor_get_f32(mask_img, mx, my);
                         ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
                         // masked image
                         for (int k = 0; k < masked_latent->ne[2]; k++) {
                             float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                            ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
+                            ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
                         }
+                        // downsampled mask
+                        ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                        // control (todo: support this)
+                        for (int k = 0; k < masked_latent->ne[2]; k++) {
+                            ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                        }
                     }
                 }
@@ -2373,12 +2453,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
             }
         }
-        if (!sd_ctx->sd->use_tiny_autoencoder) {
-            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-            init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
-        } else {
-            init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-        }
     } else {
         LOG_INFO("TXT2IMG");
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {