refector: optimize the usage of tensor_types

This commit is contained in:
leejet
2025-07-28 23:18:29 +08:00
parent 7eb30d00e5
commit f6b9aa1a43
16 changed files with 119 additions and 111 deletions

View File

@@ -32,9 +32,9 @@ struct UNetModel : public DiffusionModel {
UNetModelRunner unet;
UNetModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
SDVersion version = VERSION_SD1,
bool flash_attn = false)
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_SD1,
bool flash_attn = false)
: unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
}
@@ -85,7 +85,7 @@ struct MMDiTModel : public DiffusionModel {
MMDiTRunner mmdit;
MMDiTModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types)
const String2GGMLType& tensor_types = {})
: mmdit(backend, tensor_types, "model.diffusion_model") {
}
@@ -135,10 +135,10 @@ struct FluxModel : public DiffusionModel {
Flux::FluxRunner flux;
FluxModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
const String2GGMLType& tensor_types = {},
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
}