feat: add flux support (#356)

* add flux support

* avoid build failures in non-CUDA environments

* fix schnell support

* add k quants support

* add support for applying lora to quantized tensors

* add inplace conversion support for f8_e4m3 (#359)

in the same way it is done for bf16:
just as bf16 converts losslessly to fp32,
f8_e4m3 converts losslessly to fp16 (a small worked example follows this list)

* add xlabs flux comfy converted lora support

* update docs
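
A quick way to check the losslessness claim for f8_e4m3 -> fp16 is the sketch below. It is not part of this commit: the reference decoder f8_e4m3_ref_to_f32 and the sample bytes are illustrative only, and it assumes the OCP e4m3 layout (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, NaN encoded as 0x7f/0xff, no infinities).

// sketch.cpp - naive e4m3 reference decoder used only to sanity-check the range argument
#include <cmath>
#include <cstdint>
#include <cstdio>

// decode one f8_e4m3 byte the slow way
static float f8_e4m3_ref_to_f32(uint8_t f8) {
    float sign     = (f8 & 0x80) ? -1.0f : 1.0f;
    int   exponent = (f8 & 0x78) >> 3;
    int   mantissa = f8 & 0x07;
    if (exponent == 0x0f && mantissa == 0x07) return NAN;                 // e4m3 has NaN but no inf
    if (exponent == 0) return sign * std::ldexp(mantissa / 8.0f, 1 - 7);  // subnormal
    return sign * std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);       // normal
}

int main() {
    // the largest finite e4m3 value is 0x7e -> 448.0, far below the fp16 max of 65504,
    // and the smallest subnormal 0x01 -> 2^-9 is still an fp16 normal,
    // so every finite e4m3 value has an exact fp16 representation
    printf("0x3c -> %g\n", f8_e4m3_ref_to_f32(0x3c)); // 1.5
    printf("0x7e -> %g\n", f8_e4m3_ref_to_f32(0x7e)); // 448
    printf("0x01 -> %g\n", f8_e4m3_ref_to_f32(0x01)); // 0.001953125 (2^-9)
    return 0;
}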

---------

Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>
leejet authored 2024-08-24 14:29:52 +08:00, committed by GitHub
parent 697d000f49
commit 64d231f384
25 changed files with 1886 additions and 172 deletions

model.cpp (157 changed lines; only this file's diff is shown below)

@@ -422,7 +422,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
return key;
}
-std::string convert_tensor_name(const std::string& name) {
+std::string convert_tensor_name(std::string name) {
if (starts_with(name, "diffusion_model")) {
name = "model." + name;
}
std::string new_name = name;
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
new_name = convert_open_clip_to_hf_clip(name);
@@ -554,6 +557,48 @@ float bf16_to_f32(uint16_t bfloat16) {
return *reinterpret_cast<float*>(&val_bits);
}
uint16_t f8_e4m3_to_f16(uint8_t f8) {
// do we need to support the e4m3fnuz variant too?
const uint32_t exponent_bias = 7;
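// e4m3 layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits;
// 0x7f and 0xff are the only NaN encodings and there are no infinities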
if (f8 == 0xff) {
return ggml_fp32_to_fp16(-NAN);
} else if (f8 == 0x7f) {
return ggml_fp32_to_fp16(NAN);
}
uint32_t sign = f8 & 0x80;
uint32_t exponent = (f8 & 0x78) >> 3;
uint32_t mantissa = f8 & 0x07;
uint32_t result = sign << 24;
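// assemble an IEEE-754 binary32 bit pattern in 'result', then narrow it to fp16 at the end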
if (exponent == 0) {
if (mantissa > 0) {
exponent = 0x7f - exponent_bias;
// yes, 2 times: a nonzero subnormal mantissa has at most two leading zero bits to shift out
if ((mantissa & 0x04) == 0) {
mantissa &= 0x03;
mantissa <<= 1;
exponent -= 1;
}
if ((mantissa & 0x04) == 0) {
mantissa &= 0x03;
mantissa <<= 1;
exponent -= 1;
}
result |= (mantissa & 0x03) << 21;
result |= exponent << 23;
}
} else {
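// normal number: re-bias the exponent from 7 to 127 and left-align the 3-bit mantissa in the 23-bit field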
result |= mantissa << 20;
exponent += 0x7f - exponent_bias;
result |= exponent << 23;
}
return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
}
void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
@@ -561,6 +606,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
}
}
void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
// support inplace op
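// iterate backwards so the 1-byte -> 2-byte widening never overwrites source bytes that are still unread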
for (int64_t i = n - 1; i >= 0; i--) {
dst[i] = f8_e4m3_to_f16(src[i]);
}
}
void convert_tensor(void* src,
ggml_type src_type,
void* dst,
@@ -794,6 +846,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
ttype = GGML_TYPE_F32;
} else if (dtype == "F32") {
ttype = GGML_TYPE_F32;
} else if (dtype == "F8_E4M3") {
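// f8_e4m3 data is converted to f16 while loading, so report F16 here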
ttype = GGML_TYPE_F16;
}
return ttype;
}
@@ -866,7 +920,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
ggml_type type = str_to_ggml_type(dtype);
if (type == GGML_TYPE_COUNT) {
LOG_ERROR("unsupported dtype '%s'", dtype.c_str());
LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
return false;
}
@@ -903,6 +957,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
if (dtype == "BF16") {
tensor_storage.is_bf16 = true;
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else if (dtype == "F8_E4M3") {
tensor_storage.is_f8_e4m3 = true;
// f8 -> f16: nbytes() is computed from the f16 type, so it is twice the raw data size on disk
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else {
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
}
@@ -1291,15 +1349,22 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight;
bool is_flux = false;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
return VERSION_FLUX_DEV;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) {
-return VERSION_3_2B;
+return VERSION_SD3_2B;
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
-return VERSION_XL;
+return VERSION_SDXL;
}
if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
-return VERSION_XL;
+return VERSION_SDXL;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
return VERSION_SVD;
@@ -1315,10 +1380,13 @@ SDVersion ModelLoader::get_sd_version() {
// break;
}
}
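// a flux model without guidance_in tensors is schnell; the dev variant already returned above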
if (is_flux) {
return VERSION_FLUX_SCHNELL;
}
if (token_embedding_weight.ne[0] == 768) {
-return VERSION_1_x;
+return VERSION_SD1;
} else if (token_embedding_weight.ne[0] == 1024) {
-return VERSION_2_x;
+return VERSION_SD2;
}
return VERSION_COUNT;
}
@@ -1330,8 +1398,68 @@ ggml_type ModelLoader::get_sd_wtype() {
}
if (tensor_storage.name.find(".weight") != std::string::npos &&
(tensor_storage.name.find("time_embed") != std::string::npos) ||
tensor_storage.name.find("context_embedder") != std::string::npos) {
(tensor_storage.name.find("time_embed") != std::string::npos ||
tensor_storage.name.find("context_embedder") != std::string::npos ||
tensor_storage.name.find("time_in") != std::string::npos)) {
return tensor_storage.type;
}
}
return GGML_TYPE_COUNT;
}
ggml_type ModelLoader::get_conditioner_wtype() {
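// weight type of the text-encoder (conditioner) tensors, reported separately from the diffusion model and VAE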
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
if ((tensor_storage.name.find("text_encoders") == std::string::npos &&
tensor_storage.name.find("cond_stage_model") == std::string::npos &&
tensor_storage.name.find("te.text_model.") == std::string::npos &&
tensor_storage.name.find("conditioner") == std::string::npos)) {
continue;
}
if (tensor_storage.name.find(".weight") != std::string::npos) {
return tensor_storage.type;
}
}
return GGML_TYPE_COUNT;
}
ggml_type ModelLoader::get_diffusion_model_wtype() {
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) {
continue;
}
if (tensor_storage.name.find(".weight") != std::string::npos &&
(tensor_storage.name.find("time_embed") != std::string::npos ||
tensor_storage.name.find("context_embedder") != std::string::npos ||
tensor_storage.name.find("time_in") != std::string::npos)) {
return tensor_storage.type;
}
}
return GGML_TYPE_COUNT;
}
ggml_type ModelLoader::get_vae_wtype() {
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}
if (tensor_storage.name.find("vae.") == std::string::npos &&
tensor_storage.name.find("first_stage_model") == std::string::npos) {
continue;
}
if (tensor_storage.name.find(".weight")) {
return tensor_storage.type;
}
}
return GGML_TYPE_COUNT;
}
@@ -1467,6 +1595,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
}
} else {
read_buffer.resize(tensor_storage.nbytes());
@@ -1475,6 +1606,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
}
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
@@ -1487,6 +1621,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.is_bf16) {
// inplace op
bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
}
if (tensor_storage.type == dst_tensor->type) {
@@ -1602,7 +1739,7 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
ggml_type tensor_type = tensor_storage.type;
if (type != GGML_TYPE_COUNT) {
-if (ggml_is_quantized(type) && tensor_storage.ne[0] % 32 != 0) {
+if (ggml_is_quantized(type) && tensor_storage.ne[0] % ggml_blck_size(type) != 0) {
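// rows that are not a multiple of the quant block size (e.g. 256 for k-quants) fall back to F16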
tensor_type = GGML_TYPE_F16;
} else {
tensor_type = type;