Mirror of https://github.com/BillyOutlast/rocm-stable-diffusion.cpp.git, synced 2026-02-04 03:01:18 +01:00
feat: overriding quant types for specific tensors on model conversion (#724)
Changed files: model.cpp (67 changed lines)
@@ -100,7 +100,7 @@ const char* unused_tensors[] = {
     "model_ema.diffusion_model",
     "embedding_manager",
     "denoiser.sigmas",
-    "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight",  // only used during training
+    "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight",  // only used during training
 };
 
 bool is_unused_tensor(std::string name) {
@@ -1169,7 +1169,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             n_dims = 1;
         }
 
-
         TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
         tensor_storage.reverse_ne();
 
@@ -1914,7 +1913,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
         };
         int tensor_count = 0;
         int64_t t1 = ggml_time_ms();
-        bool partial = false;
+        bool partial = false;
         for (auto& tensor_storage : processed_tensor_storages) {
             if (tensor_storage.file_index != file_index) {
                 ++tensor_count;
@@ -1997,9 +1996,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             }
         }
         size_t tensor_max = processed_tensor_storages.size();
-        int64_t t2 = ggml_time_ms();
+        int64_t t2 = ggml_time_ms();
         pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f);
-        t1 = t2;
+        t1 = t2;
         partial = tensor_count != tensor_max;
     }
 
@@ -2088,6 +2087,41 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
     return true;
 }
 
+std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : splitString(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {
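For readers skimming the hunk above: a rules string is a comma-separated list of pattern=type pairs, where the pattern is a regular expression and the type is a ggml type name such as f16 or q8_0. f32 is special-cased because the generic lookup loop requires a to_float trait, which the f32 type does not need. Below is a minimal standalone sketch of the same parsing shape; the type-name lookup against ggml_get_type_traits() is replaced with a hard-coded list, and splitString() with std::getline, purely for illustration.

// Hypothetical standalone mock-up of the parsing logic above: split on ',',
// split each item on '=', validate the type name against a fixed list.
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

int main() {
    std::string rules = "^vae\\.=f16,model\\.=q8_0,bad-rule,attn=q9_9";
    const std::vector<std::string> known_types = {"f32", "f16", "q8_0", "q4_0", "q4_1"};

    std::stringstream ss(rules);
    std::string item;
    while (std::getline(ss, item, ',')) {  // stand-in for splitString(rules, ',')
        if (item.empty())
            continue;
        std::string::size_type pos = item.find('=');
        if (pos == std::string::npos) {
            std::printf("ignoring invalid quant override \"%s\"\n", item.c_str());
            continue;
        }
        std::string pattern   = item.substr(0, pos);
        std::string type_name = item.substr(pos + 1);
        bool valid = false;
        for (const auto& t : known_types)
            valid = valid || (t == type_name);
        if (valid)
            std::printf("rule: \"%s\" -> %s\n", pattern.c_str(), type_name.c_str());
        else
            std::printf("ignoring invalid quant override \"%s\"\n", item.c_str());
    }
    return 0;
}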
@@ -2119,7 +2153,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
     return false;
 }
 
-bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
+bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
     auto backend = ggml_backend_cpu_init();
     size_t mem_size = 1 * 1024 * 1024;  // for padding
     mem_size += tensor_storages.size() * ggml_tensor_overhead();
@@ -2129,12 +2163,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
 
     gguf_context* gguf_ctx = gguf_init_empty();
 
+    auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
+
     auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
         const std::string& name = tensor_storage.name;
-        ggml_type tensor_type = tensor_storage.type;
+        ggml_type dst_type = type;
 
-        if (tensor_should_be_converted(tensor_storage, type)) {
-            tensor_type = type;
+        ggml_type tensor_type = tensor_storage.type;
+        for (const auto& tensor_type_rule : tensor_type_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+
+        if (tensor_should_be_converted(tensor_storage, dst_type)) {
+            tensor_type = dst_type;
         }
 
         ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
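Two details of the callback above are worth noting: rules apply in the order given and the first regex that matches wins (the break), and std::regex_search matches anywhere in the tensor name, so an unanchored pattern like "attn" hits every tensor containing that substring. A small illustration, with made-up tensor names and type-name strings standing in for ggml_type values:

#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

int main() {
    // Hypothetical rule list, shaped like the output of parse_tensor_type_rules.
    std::vector<std::pair<std::string, std::string>> rules = {
        {"^vae\\.", "f16"},   // anchored: only names starting with "vae."
        {"attn",    "q8_0"},  // unanchored: "attn" anywhere in the name
    };
    std::vector<std::string> names = {
        "vae.decoder.conv_in.weight",
        "model.diffusion_model.attn1.to_q.weight",
        "model.diffusion_model.ff.net.0.weight",
    };
    for (const auto& name : names) {
        std::string dst_type = "(global default)";
        for (const auto& rule : rules) {
            std::regex pattern(rule.first);
            if (std::regex_search(name, pattern)) {
                dst_type = rule.second;  // first matching rule wins
                break;
            }
        }
        std::printf("%-45s -> %s\n", name.c_str(), dst_type.c_str());
    }
    return 0;
}

One design note: the diff constructs the std::regex inside the per-tensor loop, so each pattern is recompiled for every tensor. For large models, compiling the patterns once in parse_tensor_type_rules and storing std::regex objects would avoid that repeated cost.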
@@ -2193,7 +2238,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
     return mem_size;
 }
 
-bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
+bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) {
     ModelLoader model_loader;
 
     if (!model_loader.init_from_file(input_path)) {
@@ -2207,6 +2252,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
             return false;
         }
     }
-    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
+    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
     return success;
 }
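Putting it together, a caller of the updated public convert() entry point might look like the sketch below. The file paths, the rule string, and the SD_TYPE_Q4_0 constant spelling are illustrative assumptions; check stable-diffusion.h for the exact sd_type_t enum names, and note that an empty vae_path is assumed here to mean "no external VAE".

#include "stable-diffusion.h"

int main() {
    // Quantize the whole model to q4_0, but keep anything under "vae." at f16
    // (rule string and paths are hypothetical examples).
    bool ok = convert("sd-v1-4.safetensors",   // input_path
                      "",                      // vae_path (none)
                      "sd-v1-4-q4_0.gguf",     // output_path
                      SD_TYPE_Q4_0,            // global output type
                      "^vae\\.=f16");          // tensor_type_rules
    return ok ? 0 : 1;
}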