feat: add flux support (#356)

* add flux support

* avoid build failures in non-CUDA environments

* fix schnell support

* add k quants support

* add support for applying lora to quantized tensors
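
A rough sketch of the general idea (a toy per-tensor int8 scheme written for
illustration only, not the repository's ggml quantization formats or its actual
lora code): dequantize the weight to float, add the scaled low-rank update
B x A, then requantize.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// toy illustration: apply W += (alpha / rank) * B * A to an int8-quantized W,
// where the quantized value q[i] decodes to q[i] * scale
struct QuantizedTensor {
    std::vector<int8_t> q;  // rows * cols entries
    float scale;
    int rows, cols;
};

static void apply_lora_to_quantized(QuantizedTensor& w,
                                    const std::vector<float>& A,  // rank x cols
                                    const std::vector<float>& B,  // rows x rank
                                    int rank, float alpha) {
    const float lora_scale = alpha / rank;

    // 1. dequantize to float
    std::vector<float> wf(w.q.size());
    for (std::size_t i = 0; i < wf.size(); i++) {
        wf[i] = w.q[i] * w.scale;
    }

    // 2. add the low-rank update
    for (int r = 0; r < w.rows; r++) {
        for (int c = 0; c < w.cols; c++) {
            float delta = 0.0f;
            for (int k = 0; k < rank; k++) {
                delta += B[r * rank + k] * A[k * w.cols + c];
            }
            wf[r * w.cols + c] += lora_scale * delta;
        }
    }

    // 3. requantize with a fresh per-tensor scale
    float amax = 0.0f;
    for (float v : wf) {
        amax = std::max(amax, std::fabs(v));
    }
    w.scale = amax > 0.0f ? amax / 127.0f : 1.0f;
    for (std::size_t i = 0; i < wf.size(); i++) {
        w.q[i] = (int8_t)std::lround(wf[i] / w.scale);
    }
}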

* add inplace conversion support for f8_e4m3 (#359)

in the same way it is done for bf16: just as bf16 converts
losslessly to fp32, f8_e4m3 converts losslessly to fp16
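
Since e4m3's 4-bit exponent (bias 7) and 3-bit mantissa both fit inside fp16's
5-bit exponent (bias 15) and 10-bit mantissa, the widening only needs a rebias
and a shift. A standalone sketch of such a conversion (a hypothetical helper
written for this note, not the code added in #359):

#include <cstdint>

// convert one fp8 e4m3 value (1 sign bit, 4 exponent bits with bias 7,
// 3 mantissa bits) to an IEEE fp16 bit pattern; every finite e4m3 value is
// exactly representable in fp16, so the mapping is lossless
static uint16_t f8_e4m3_to_f16_bits(uint8_t f8) {
    uint16_t sign = (uint16_t)(f8 & 0x80) << 8;  // sign: bit 7 -> bit 15
    uint8_t exp   = (f8 >> 3) & 0x0F;            // 4-bit exponent field
    uint8_t man   = f8 & 0x07;                   // 3-bit mantissa field

    if (exp == 0x0F && man == 0x07) {  // e4m3 NaN (e4m3 has no infinities)
        return sign | 0x7E00;          // quiet NaN in fp16
    }
    if (exp == 0) {
        if (man == 0) {
            return sign;  // +/- zero
        }
        // e4m3 subnormal: value = (man / 8) * 2^-6; renormalize, because it is
        // still a normal number in fp16 (whose smallest normal is 2^-14)
        int shift = 0;
        while ((man & 0x08) == 0) {  // shift until the implicit 1 appears
            man <<= 1;
            shift++;
        }
        man &= 0x07;  // drop the implicit 1 again
        uint16_t f16_exp = (uint16_t)(15 - 6 - shift);
        return sign | (f16_exp << 10) | ((uint16_t)man << 7);
    }
    // normal number: rebias the exponent (7 -> 15), widen the mantissa (3 -> 10 bits)
    uint16_t f16_exp = (uint16_t)(exp - 7 + 15);
    return sign | (f16_exp << 10) | ((uint16_t)man << 7);
}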

* add xlabs flux comfy converted lora support

* update docs

---------

Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>

Author: leejet
Date: 2024-08-24 14:29:52 +08:00
Committed by: GitHub
Parent: 697d000f49
Commit: 64d231f384
25 changed files with 1886 additions and 172 deletions

diffusion_model.hpp

@@ -3,6 +3,7 @@
#include "mmdit.hpp"
#include "unet.hpp"
#include "flux.hpp"
struct DiffusionModel {
virtual void compute(int n_threads,
@@ -11,6 +12,7 @@ struct DiffusionModel {
                          struct ggml_tensor* context,
                          struct ggml_tensor* c_concat,
                          struct ggml_tensor* y,
+                         struct ggml_tensor* guidance,
                          int num_video_frames = -1,
                          std::vector<struct ggml_tensor*> controls = {},
                          float control_strength = 0.f,
@@ -29,7 +31,7 @@ struct UNetModel : public DiffusionModel {
     UNetModel(ggml_backend_t backend,
               ggml_type wtype,
-              SDVersion version = VERSION_1_x)
+              SDVersion version = VERSION_SD1)
         : unet(backend, wtype, version) {
     }
@@ -63,6 +65,7 @@ struct UNetModel : public DiffusionModel {
                  struct ggml_tensor* context,
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
+                 struct ggml_tensor* guidance,
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
@@ -77,7 +80,7 @@ struct MMDiTModel : public DiffusionModel {
     MMDiTModel(ggml_backend_t backend,
                ggml_type wtype,
-               SDVersion version = VERSION_3_2B)
+               SDVersion version = VERSION_SD3_2B)
         : mmdit(backend, wtype, version) {
     }
@@ -111,6 +114,7 @@ struct MMDiTModel : public DiffusionModel {
                  struct ggml_tensor* context,
                  struct ggml_tensor* c_concat,
                  struct ggml_tensor* y,
+                 struct ggml_tensor* guidance,
                  int num_video_frames = -1,
                  std::vector<struct ggml_tensor*> controls = {},
                  float control_strength = 0.f,
@@ -120,4 +124,54 @@ struct MMDiTModel : public DiffusionModel {
     }
 };
 
+struct FluxModel : public DiffusionModel {
+    Flux::FluxRunner flux;
+
+    FluxModel(ggml_backend_t backend,
+              ggml_type wtype,
+              SDVersion version = VERSION_FLUX_DEV)
+        : flux(backend, wtype, version) {
+    }
+
+    void alloc_params_buffer() {
+        flux.alloc_params_buffer();
+    }
+
+    void free_params_buffer() {
+        flux.free_params_buffer();
+    }
+
+    void free_compute_buffer() {
+        flux.free_compute_buffer();
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+        flux.get_param_tensors(tensors, "model.diffusion_model");
+    }
+
+    size_t get_params_buffer_size() {
+        return flux.get_params_buffer_size();
+    }
+
+    int64_t get_adm_in_channels() {
+        return 768;
+    }
+
+    void compute(int n_threads,
+                 struct ggml_tensor* x,
+                 struct ggml_tensor* timesteps,
+                 struct ggml_tensor* context,
+                 struct ggml_tensor* c_concat,
+                 struct ggml_tensor* y,
+                 struct ggml_tensor* guidance,
+                 int num_video_frames = -1,
+                 std::vector<struct ggml_tensor*> controls = {},
+                 float control_strength = 0.f,
+                 struct ggml_tensor** output = NULL,
+                 struct ggml_context* output_ctx = NULL) {
+        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
+    }
+};
+
 #endif