model: Add support for GLM 4.5 family of models (#14921) #14939
CISC merged 13 commits into ggml-org:master from sammcj:glm-4-5
Conversation
|
Just a few quick notes from a glance:
Will do a proper review when you are ready. :) |
|
Hey @CISC, no worries on the naming etc., will do. |
|
FYI when trying to run |
That's because converting FP8 weights isn't supported yet, see #14810 |
|
I'm close to having convert_hf_to_gguf.py and llama-quantize working (see updated PR). It completes conversion without error and I was then able to quantise to Q4_K_M. gguf-dump worked, but llama-server picked up a tensor mapping issue with token_embd.weight, so I've just put a fix into convert_hf_to_gguf.py. I'm going through the whole conversion then quantisation process again. It's getting late here (Hi from Melbourne 👋), so I'll come back and see if it's finished in ~20 minutes. |
|
The LLM_TYPE code is wrong, those models aren't (respectively) dense 12B and 32B models. You have to add new MoE constants for them (see Qwen3 and Ernie MoEs as examples). |
|
Also, you might want to include the nextn tensors instead of throwing them out - MTP support is not there yet, but that way you won't have to reconvert and requantize if/when it arrives. |
|
Thanks @pwilkin, LLM_TYPE updated. I've added the nextn tensors into the conversion, skipping mapping to avoid errors. |
|
Note that preserving the nextn tensors does result in a larger GGUF (780 tensors -> 1184 & 214GB -> 221GB for the f16) |
|
I can't replicate that error @Thireus |
Obviously, but they won't get loaded since they're not supported 😄 Also, don't make my mistake: Don't convert to f16, do --outtype bf16 or your model will probably have errors in the tensors. |
|
If you add unused tensors to the GGUF you must mark those tensors as unused. Just FYI, all other models with MTP so far have those tensors stripped. |
Ah, that'd explain the error I'm getting. I'll have to come back to this in the morning as it's getting late here. If anyone is keen for this ASAP and has improvements, feel free to either raise a PR against my branch, or pull my commits into a PR of your own if you have a better approach, and I'll review in the morning. |
|
I'll just put it out there right now; no-one should make GGUFs from this PR public yet, there will be changes! :) |
Absolutely, I hope people do not do that - it's very much in draft and I'm learning as I go. |
|
@sammcj, 7f026fb#diff-4f653096980bd7d10518aa909cb648452cd3aa380ff93cb9fb642dca48536526 fixed the issue thanks. |
|
the fix seems to work, still testing -> INFO:hf-to-gguf:Model successfully exported to models/glm-45-air-f16.gguf |
|
@Thireus are you sure that GGUF conversion is complete and correct? By uploading potentially broken conversions to HF you could be causing a lot of people grief and wasted bandwidth. |
|
Folks, this PR is in DRAFT; it is not expected to work yet. When it does, I will move it out of draft. Until then, unless you have a code change to recommend, please hold other comments until it's ready. |
Sorry, sometimes people want help testing, so I chimed in. |
|
No worries at all, sorry if my message came across as angry (it wasn't intended to). I just want to make sure that folks don't end up wasting their time or bandwidth, or worse, blaming the llama.cpp project when it's my ability that's lacking. The best help anyone can offer right now would be code corrections (either point out the required change here, or in a PR to my branch, @'ing me so I get a notification). Thanks for wanting to help out; I promise as soon as I think it's in a state that it could be tested I will update this thread. |
|
No pressure. We're just all excited to try the new model. Thanks for doing the hard work! |
|
Bingo! I have conversion, quantisation and llama-server working! 🎉 Please feel free to test this out and if you have code changes to suggest - please do those here. Note: If you end up sharing any GGUFs built from this PR - PLEASE make it clear that they're built from a llama.cpp PR (aka an unofficial fork) and that there may be changes before it's stable. |
|
Thanks! |
|
@CISC I narrowed down the gibberish issue a bit. It requires setting --batch-size 4096 --ubatch-size 4096 and possibly having a long multi-turn chat going. When I removed the batch-size / ubatch-size, my 40k and 50k token chats began working again. Setting the sizes up to 2048 / 2048 also worked. Something about 4096 / 4096 combined with over 32k context across multiple turns leads to that gibberish edge case. I also tried a needle in a haystack test with a 35k token prompt with a direction to answer a question from the text as a one-shot and that worked. So I don't have a reproducible smoking gun, but batch-size / ubatch-size is involved and for now I'm just scaling them back to make it work. |
Ah, ok, so that means it's not a model issue then, that's great! Submit an issue though. :) |
|
Just FYI for anyone wanting to create i-quants: as the final layer will not get imatrix data until MTP is supported, it has to be overridden for lower quants to work, e.g. using |
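A toy sketch of why that override is needed (illustrative only; the tensor names and the `collect_imatrix` helper are made up, not how llama-imatrix is actually implemented): imatrix statistics are accumulated only for tensors the compute graph actually executes, so the unused NextN/MTP layer ends up with no entries.

```python
# Hypothetical, simplified picture of imatrix collection: activation stats
# are accumulated per tensor, but only for tensors the graph actually runs.
def collect_imatrix(all_tensors, executed_tensors):
    return {name: "activation stats" for name in all_tensors if name in executed_tensors}

# Made-up tensor names; "blk.2.nextn..." stands in for the unused MTP layer.
all_tensors = [
    "blk.0.ffn_down_exps.weight",
    "blk.1.ffn_down_exps.weight",
    "blk.2.nextn.ffn_down_exps.weight",
]
executed = {"blk.0.ffn_down_exps.weight", "blk.1.ffn_down_exps.weight"}

imatrix = collect_imatrix(all_tensors, executed)

# The NextN/MTP tensor never ran, so it has no imatrix data; low-bit
# quantization of it must therefore be overridden to a higher-bit type.
print("blk.2.nextn.ffn_down_exps.weight" in imatrix)  # False
```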
|
I am getting over 45t/s on three 3090s on unsloth quant Q4 for GLM Air, here is the optimized command: |
I can confirm it's not warming up. If I patch `llama_context::graph_max_nodes()`:

```cpp
uint32_t llama_context::graph_max_nodes() const {
    //return std::max<uint32_t>(1024u, 8u*model.n_tensors());
    return std::max<uint32_t>(65536u, 8u*model.n_tensors());
}
```

and then run with it, warmup works. You then need to rerun without it. I've got to go out, so no more time to investigate until later. |
|
Actually, no it's still not warming up properly - it's just a lot quicker because it's got the experts mmapped I think... Will see if I can figure it out later if nobody else has by then. |
|
I've found it:

```cpp
// MoE layer with shared experts
//const int64_t n_expert = hparams.n_expert;
//const int64_t n_expert_used = hparams.n_expert_used;

// Process routed experts using existing MoE infrastructure
ggml_tensor * routed_out = build_moe_ffn(cur,
        model.layers[il].ffn_gate_inp,
        model.layers[il].ffn_up_exps,
        model.layers[il].ffn_gate_exps,
        model.layers[il].ffn_down_exps,
        model.layers[il].ffn_exp_probs_b,
        n_expert, n_expert_used,
        LLM_FFN_SILU, hparams.expert_weights_norm,
        true, hparams.expert_weights_scale,
        (llama_expert_gating_func_type) hparams.expert_gating_func,
        il);
cb(routed_out, "ffn_moe_out", il);
```

The local `n_expert` and `n_expert_used` variables were shadowing the `llm_graph_context` members, which are the ones that account for warmup:

```cpp
llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    arch             (params.arch),
    hparams          (params.hparams),
    cparams          (params.cparams),
    ubatch           (params.ubatch),
    n_embd           (hparams.n_embd),
    n_layer          (hparams.n_layer),
    n_rot            (hparams.n_rot),
    n_ctx            (cparams.n_ctx),
    n_head           (hparams.n_head()),
    n_head_kv        (hparams.n_head_kv()),
    n_embd_head_k    (hparams.n_embd_head_k),
    n_embd_k_gqa     (hparams.n_embd_k_gqa()),
    n_embd_head_v    (hparams.n_embd_head_v),
    n_embd_v_gqa     (hparams.n_embd_v_gqa()),
    n_expert         (hparams.n_expert),
    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    freq_base        (cparams.rope_freq_base),
    freq_scale       (cparams.rope_freq_scale),
    ext_factor       (cparams.yarn_ext_factor),
    attn_factor      (cparams.yarn_attn_factor),
    beta_fast        (cparams.yarn_beta_fast),
    beta_slow        (cparams.yarn_beta_slow),
    norm_eps         (hparams.f_norm_eps),
    norm_rms_eps     (hparams.f_norm_rms_eps),
    n_tokens         (ubatch.n_tokens),
    n_outputs        (params.n_outputs),
    n_ctx_orig       (cparams.n_ctx_orig_yarn),
    pooling_type     (cparams.pooling_type),
    rope_type        (hparams.rope_type),
    sched            (params.sched),
    backend_cpu      (params.backend_cpu),
    cvec             (params.cvec),
    loras            (params.loras),
    mctx             (params.mctx),
    cross            (params.cross),
    cb_func          (params.cb),
    res              (params.res),
    ctx0             (res->get_ctx()),
    gf               (res->get_gf()) {
    res->set_params(params);
}
```

Note the `n_expert_used` initializer: during warmup it is set to `hparams.n_expert` so that every expert gets exercised, which the shadowing locals bypassed. |
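To illustrate the `cparams.warmup ? hparams.n_expert : hparams.n_expert_used` trick, here's a minimal Python sketch (made-up names and numbers, not llama.cpp code): with top-k routing, an mmap'd expert tensor is only paged in the first time that expert is selected, so warmup needs to select every expert at least once.

```python
# Toy sketch of MoE top-k routing vs "route to everything" during warmup.
def experts_touched(n_expert: int, n_expert_used: int, gate_scores):
    """Return the set of expert ids selected for one token (top-k by score)."""
    ranked = sorted(range(n_expert), key=lambda e: gate_scores[e], reverse=True)
    return set(ranked[:n_expert_used])

n_expert, n_expert_used = 8, 2
scores = [0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7, 0.4]

# Normal decoding: only the top-k experts for this token get touched.
normal = experts_touched(n_expert, n_expert_used, scores)

# During warmup, llama.cpp effectively sets n_expert_used = n_expert,
# so every expert tensor is touched (and paged in) at least once.
warmup = experts_touched(n_expert, n_expert, scores)

print(len(normal))  # 2
print(len(warmup))  # 8
```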
|
@jukofyork confirmed. This fixes warmup for me. It also restores GLM-4.5 to the performance levels I've come to expect from llama.cpp.
Startup command:

```shell
./build/bin/llama-server \
    --model /data/GLM-4.5-GGUF/q4_k_m/GLM-4.5-Q4_K_M.gguf \
    --alias GLM-4.5-GGUF:q4_k_m \
    --no-webui \
    --numa numactl \
    --threads 32 \
    --ctx-size 131072 \
    --n-gpu-layers 94 \
    -ot "blk\.(3|4|5|6|7|8|9|10|11|12|13|14|15|16|17)\.ffn_.*=CUDA0" \
    -ot exps=CPU \
    -ub 4096 -b 4096 \
    --seed 3407 \
    --temp 0.6 \
    --top-p 1.0 \
    --log-colors \
    --flash-attn \
    --host 0.0.0.0 \
    --jinja \
    --port 11434
```

I had GLM-4.5 write a poem for you: |
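For anyone puzzling over the `-ot` flags: each one is a `regex=backend` pair matched against tensor names. A rough Python sketch of the assumed first-match-wins semantics (a simplification for illustration, not llama.cpp's actual code; the patterns here are shortened):

```python
import re

# Hypothetical simplified model of llama.cpp's --override-tensor (-ot)
# matching: each override is "regex=backend"; the first regex that matches
# a tensor name decides where that tensor is placed.
overrides = [
    (re.compile(r"blk\.(3|4|5)\.ffn_.*"), "CUDA0"),  # shortened layer list
    (re.compile(r"exps"), "CPU"),
]

def place(tensor_name: str, default: str = "CUDA0") -> str:
    for pattern, backend in overrides:
        if pattern.search(tensor_name):
            return backend
    return default

print(place("blk.4.ffn_gate_exps.weight"))   # CUDA0 (first rule wins)
print(place("blk.40.ffn_gate_exps.weight"))  # CPU   (falls through to "exps")
print(place("blk.40.attn_q.weight"))         # CUDA0 (no override matches)
```

This is why the order of the two `-ot` flags matters: the layer-specific rule must come before the catch-all `exps=CPU` rule, or those layers' experts would land on the CPU too.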
|
No problem, and I can confirm it's running as expected for me now too (~6.5 tokens/s generation). I've managed to transplant the vocab, so assuming it trains OK, we should have a draft model in a day or so. It actually looks to have transplanted very well, as even the untrained draft is getting a high acceptance rate for refactoring tasks: |
|
Yesterday a bug was found for these models in vLLM and patched. The PR in question is vllm-project/vllm#22203. Does anyone know if this implementation is using float32 data for the self.gate module? Because if not, it might need a similar fix. |
isn't this related? |
That's just for |
I am not sure how to see the difference. Should the perplexity change? I tried the following fix, but the perplexity stays the same. |
Just checked, and it's this change in vLLM:

```python
router_logits, _ = self.gate(hidden_states)
```

became

```python
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
```

which I think is always kept as `F32` in llama.cpp anyway. In convert_hf_to_gguf.py:

```python
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp
# Some tensor types are always in float32
if data_qtype is False and (
    any(
        self.match_model_tensor_name(new_name, key, bid)
        for key in (
            gguf.MODEL_TENSOR.FFN_GATE_INP,
            gguf.MODEL_TENSOR.POS_EMBD,
            gguf.MODEL_TENSOR.TOKEN_TYPES,
            gguf.MODEL_TENSOR.SSM_CONV1D,
            gguf.MODEL_TENSOR.SHORTCONV_CONV,
            gguf.MODEL_TENSOR.TIME_MIX_FIRST,
            gguf.MODEL_TENSOR.TIME_MIX_W1,
            gguf.MODEL_TENSOR.TIME_MIX_W2,
            gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
            gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
            gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
            gguf.MODEL_TENSOR.POSNET_NORM1,
            gguf.MODEL_TENSOR.POSNET_NORM2,
            gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
            gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
            gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
            gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
        )
    )
    or not new_name.endswith(".weight")
):
    data_qtype = gguf.GGMLQuantizationType.F32
```

and in llama.cpp's quantizer:

```cpp
// do not quantize expert gating tensors
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
```

then IIRC, in the backends, any time a |
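To make the precision concern concrete, a toy Python illustration (synthetic numbers, unrelated to the actual model): two router logits that are distinct in `float32` can round to the same `float16` value, flipping which expert wins the top-k.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

# Two synthetic router logits closer together than fp16 can resolve
# (fp16 spacing near 2.0 is about 0.00195).
logits = [2.3456, 2.3460]

argmax_fp32 = max(range(len(logits)), key=lambda i: logits[i])
argmax_fp16 = max(range(len(logits)), key=lambda i: to_fp16(logits[i]))

print(argmax_fp32)  # 1: expert 1 has the larger logit in full precision
print(argmax_fp16)  # 0: both round to the same fp16 value; the tie goes to expert 0
```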
|
@jukofyork Thanks for checking, then all is good. |
|
https://huggingface.co/jukofyork/GLM-4.5-DRAFT-0.6B-v3.0 This should hopefully also work on |
|
@jukofyork I haven't tested with this model, but shouldn't this work without a specially crafted model since universal assisted decoding was merged? Edit: Requires a fair amount of |
Yeah, I tested the new universal assisted decoding with
I didn't play with this setting though, so agree this might improve things. |
I think I found where the bug behind this problem resides. Please take a look at my post on #15112. It's very possible that it is an invalid CUDA graph update. If compiled with |
|
Hi, LLM noob here. How do you actually run the GLM models in llama.cpp (what's the command)? I tried looking for the gguf file in ggml-org/ but can't find it |
You can use any repo or file that is in GGUF format from huggingface, it does not have to be from ggml-org/. So, just search for "GLM-4.5-GGUF" (or 4.6), and sort by most popular to see which files people are using. |
Add support for the newly released GLM 4.5 family of models.
Core Architecture
Model Loading (src/llama-model.cpp)
Conversion Support (convert_hf_to_gguf.py)
Technical Details
MoE Architecture
Model Variants
The NextN/MTP prediction tensors are preserved during conversion but marked as unused since llama.cpp does not yet support multi-token prediction.
Testing
CI scripts run locally (CPU only) have two failing tests that I believe are unrelated to this change (please tell me if this isn't the case!):
gguf-dump
Disclaimer:
Hopefully resolves #14921