
model: Add support for GLM 4.5 family of models (#14921) #14939

Merged
CISC merged 13 commits into ggml-org:master from sammcj:glm-4-5 on Aug 4, 2025

Conversation

@sammcj (Contributor) commented Jul 29, 2025

Add support for the newly released GLM 4.5 family of models.

Core Architecture

  • Architecture Registration: Added LLM_ARCH_GLM4_MOE enum and architecture mappings
  • Tensor Definitions: Complete tensor mappings for MoE components including 128 routed experts + 1 shared expert
  • Hybrid Layer Support: Added n_layer_dense_lead parameter to handle different dense/MoE layer patterns between variants

Model Loading (src/llama-model.cpp)

  • Multi-variant Support: Automatic detection and loading for both 47-layer (Air) and 93-layer (full) models
  • MoE Infrastructure: Complete expert weight loading with merged 3D tensor format
  • Graph Implementation: New llm_build_glm4_moe class with sigmoid-based expert routing and top-8 selection
  • Shared Expert Integration: Proper handling of shared expert computation alongside routed experts
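
The sigmoid-based top-8 routing described above can be sketched in a few lines of Python (a simplified illustration of the routing math only, not the actual llm_build_glm4_moe graph code; function and variable names are illustrative):

```python
import math

def route_tokens(logits, top_k=8):
    """Score experts with a sigmoid (not softmax), keep the top-k,
    then normalize the selected weights so they sum to 1
    (GLM 4.5 sets expert_weights_norm = true)."""
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    chosen = ranked[:top_k]
    total = sum(scores[i] for i in chosen)
    return [(i, scores[i] / total) for i in chosen]

# 128 routed experts per MoE layer; each token activates 8 of them,
# and the shared expert's output is added unconditionally afterwards.
weights = route_tokens([0.1 * i for i in range(128)])
assert len(weights) == 8
assert abs(sum(w for _, w in weights) - 1.0) < 1e-9
```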

Conversion Support (convert_hf_to_gguf.py)

  • HuggingFace Integration: Complete Glm4MoeModel converter class
  • Expert Tensor Merging: Sophisticated logic to merge expert weights into GGUF 3D tensor format
  • Metadata Handling: Proper extraction and conversion of MoE parameters from HuggingFace config
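
The expert-merging step can be illustrated with a small pure-Python sketch (the real converter operates on torch tensors inside the Glm4MoeModel class; names and toy shapes here are illustrative):

```python
def merge_expert_weights(per_expert, name):
    """Collect each expert's 2D weight matrix under a new leading axis,
    mirroring the merged [n_expert, rows, cols] layout that llama.cpp
    expects for MoE tensors such as ffn_up_exps.weight."""
    return [per_expert[e][name] for e in sorted(per_expert)]

# Toy shapes: 4 experts with 2x3 up-projection weights
# (the real models use 128 experts and much larger matrices).
experts = {e: {"up_proj.weight": [[float(e)] * 3 for _ in range(2)]}
           for e in range(4)}
merged = merge_expert_weights(experts, "up_proj.weight")
assert len(merged) == 4        # n_expert is the leading dimension
assert len(merged[0]) == 2     # rows of each expert's matrix
assert merged[2][0][0] == 2.0  # expert 2's weights land at index 2
```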

Technical Details

MoE Architecture

  • Expert Count: 128 routed experts + 1 shared expert per MoE layer
  • Expert Selection: Top-8 experts per token with sigmoid-based routing (not softmax)
  • Hybrid Layers: Dense layer for layer 0, MoE for remaining layers
  • Weight Format: Expert weights stored as merged [num_experts, hidden_size, ffn_size] tensors

Model Variants

  • GLM-4.5: 355B total parameters, 32B active, 93 layers, includes K/Q norm tensors
  • GLM-4.5-Air: 106B total parameters, 12B active, 47 layers, no K/Q norm tensors

The NextN/MTP prediction tensors are preserved during conversion but marked as unused since llama.cpp does not yet support multi-token prediction.

Testing

  • Builds successfully with no compilation errors.
  • convert_hf_to_gguf.py working.
  • llama-quantize working.

CI scripts run locally (CPU only) have two failing tests that I believe are unrelated to this change (please tell me if this isn't the case!):

94% tests passed, 2 tests failed out of 35

Label Time Summary:
main    = 251.60 sec*proc (35 tests)

Total Test time (real) = 251.61 sec

The following tests FAILED:
	 14 - test-tokenizers-ggml-vocabs (Failed)
	 27 - test-thread-safety (Subprocess aborted)


Analysis of Test Failures

1. test-tokenizers-ggml-vocabs - Corrupted Test Files

gguf_init_from_file_impl: invalid magic characters: 'vers', expected 'GGUF'
- Issue: Corrupted GGUF vocabulary files (ggml-vocab-nomic-bert-moe.gguf, etc.)
- Cause: File corruption in test environment, not code changes
- Relation to GLM 4.5: None - this is about vocabulary files, not architecture definitions

2. test-thread-safety - CUDA Environment Issues

CUDA error: unspecified launch failure
current device: 1, in function ggml_backend_cuda_synchronize
- Issue: CUDA backend threading/synchronisation failure
- Cause: CUDA driver/environment issues in CI system
- Relation to GLM 4.5: None - our changes were all CPU-side model loading logic

gguf-dump output:

```plain
TODO when ready
```

Disclaimer:

  • I am certainly not an expert in this - I think this is my first attempt at contributing a new model architecture to llama.cpp.
  • The most useful feedback is the code changes to make.
  • I did leverage the smarts of AI to help with the changes.
  • If this is not up to standard or I am completely off track, please feel free to reject this PR, I totally understand if someone smarter than I could do a better job of it.

Hopefully resolves #14921

@github-actions bot added the python (python script changes) label Jul 29, 2025
@CISC (Collaborator) commented Jul 29, 2025

Just a few quick notes from a glance:

  • Please name it GLM4_MOE, not GLM45
  • There's already LLM_KV_LEADING_DENSE_BLOCK_COUNT, no need for LLM_KV_FIRST_K_DENSE_REPLACE
  • Use GGML_ASSERT instead of throwing
  • Be mindful of whitespaces and alignments

Will do a proper review when you are ready. :)

@sammcj (Contributor, Author) commented Jul 29, 2025

Hey @CISC no worries on the naming etc.. will do.
Whitespace changes will be fixed, I haven't run this through linting yet, will get back to this later tonight hopefully.

@AnneKitsune

FYI when trying to run convert_hf_to_gguf.py on GLM4.5-Air-FP8, I get that some constants ending with _EXPS don't exist. If I replace these by _EXP, then I get a different error related to matrix mapping.
Thank you for working on this!

@CISC (Collaborator) commented Jul 29, 2025

FYI when trying to run convert_hf_to_gguf.py on GLM4.5-Air-FP8, I get that some constants ending with _EXPS don't exist. If I replace these by _EXP, then I get a different error related to matrix mapping. Thank you for working on this!

That's because converting FP8 weights isn't supported yet, see #14810

@sammcj (Contributor, Author) commented Jul 29, 2025

I'm close to having convert_hf_to_gguf.py and llama-quantize working (see updated PR), it completes conversion without error and I was then able to quantise to Q4_K_M.

gguf-dump worked, but llama-server picked up a tensor mapping issue with token_embd.weight, so I've just put a fix into convert_hf_to_gguf.py.

I'm going through the whole conversion then quantisation process again. It's getting late here (Hi from Melbourne 👋), so I'll come back and see if it's finished in ~20 minutes.

@pwilkin (Collaborator) commented Jul 29, 2025

The LLM_TYPE code is wrong, those models aren't (respectively) dense 12B and 32B models. You have to add new MoE constants for them (see Qwen3 and Ernie MoEs as examples).

@pwilkin (Collaborator) commented Jul 29, 2025

Also, you might want to include the nextn tensors instead of throwing them out - MTP support is not there yet, but that way you won't have to reconvert and requantize if/when it arrives.

@sammcj (Contributor, Author) commented Jul 29, 2025

Thanks @pwilkin, LLM_TYPE updated.

I've added the nextn tensors into the conversion, skipping mapping to avoid errors.

@sammcj (Contributor, Author) commented Jul 29, 2025

Note that preserving the nextn tensors does result in a larger GGUF (780 tensors -> 1184 & 214GB -> 221GB for the f16)

@sammcj (Contributor, Author) commented Jul 29, 2025

I can't replicate that error @Thireus

@pwilkin (Collaborator) commented Jul 29, 2025

Note that preserving the nextn tensors does result in a larger GGUF (780 tensors -> 1184 & 214GB -> 221GB for the f16)

Obviously, but they won't get loaded since they're not supported 😄

Also, don't make my mistake:
"torch_dtype": "bfloat16"

Don't convert to f16, do --outtype bf16 or your model will probably have errors in the tensors.
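
A quick sanity check of why bf16 checkpoints should not be converted via f16: the two formats trade exponent bits for mantissa bits differently, so large bf16 values simply don't fit in f16. A minimal Python illustration using the standard library's half-precision pack format:

```python
import struct

def fits_f16(x: float) -> bool:
    """True if x can be stored in IEEE binary16 without overflowing."""
    try:
        struct.pack('<e', x)  # 'e' = half precision (binary16)
        return True
    except OverflowError:
        return False

# f16 has a 5-bit exponent and tops out at 65504; bf16 keeps float32's
# 8-bit exponent, so large bf16 values overflow when cast to f16.
assert fits_f16(65504.0)     # largest finite f16 value
assert not fits_f16(1e38)    # representable in bf16, overflows f16
```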

@CISC (Collaborator) commented Jul 29, 2025

If you add unused tensors to the GGUF you must mark those tensors as unused (GGML_OP_NONE) in llama-arch.cpp, otherwise you will get an error when loading the model!

Just FYI, all other models with MTP so far have those tensors stripped.

@sammcj (Contributor, Author) commented Jul 29, 2025

If you add unused tensors to the GGUF you must mark those tensors as unused (GGML_OP_NONE) in llama-arch.cpp, otherwise you will get an error when loading the model!

Ah, that'd explain why I'm getting llama_model_load: error loading model: done_getting_tensors: wrong number of tensors; expected 1184, got 735! - I'll push a change for that shortly @CISC.

I'll have to come back to this in the morning as it's getting late here.

If anyone is keen for this ASAP and has improvements feel free to either raise a PR against my branch or pull my commits into a PR of your own if you have a better approach and I'll review in the morning.

@CISC (Collaborator) commented Jul 29, 2025

I'll just put it out there right now; no-one should make GGUFs from this PR public yet, there will be changes! :)

@sammcj (Contributor, Author) commented Jul 29, 2025

I'll just put it out there right now; no-one should make GGUFs from this PR public yet, there will be changes! :)

Absolutely, I hope people do not do that - it's very much in draft and I'm learning as I go.

@Thireus (Contributor) commented Jul 29, 2025

@sammcj, 7f026fb#diff-4f653096980bd7d10518aa909cb648452cd3aa380ff93cb9fb642dca48536526 fixed the issue thanks.

@ricyoung

the fix seems to work, still testing -> INFO:hf-to-gguf:Model successfully exported to models/glm-45-air-f16.gguf

@Thireus

This comment was marked as off-topic.

@sammcj (Contributor, Author) commented Jul 29, 2025

@Thireus are you sure that GGUF conversion is complete and correct? By uploading potentially broken conversions to HF you could be causing a lot of people grief and wasted bandwidth.

@Thireus

This comment was marked as off-topic.

@anikifoss commented Jul 29, 2025

convert_hf_to_gguf.py works, but got this error trying to quantize to Q8_0:

./build/bin/llama-quantize \
    /mnt/data/Models/zai-org/GLM-4.5-Air-GGUF/GLM-4.5-Air-128x9.4B-BF16-00001-of-00005.gguf \
    /mnt/data/Models/anikifoss/GLM-4.5-Air/GLM-4.5-Air.gguf \
    Q8_0 \
    32

...

llama-glm/src/llama-quant.cpp:717: GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected") failed
[New LWP 199420]

@sammcj (Contributor, Author) commented Jul 29, 2025

Folks, this PR is in DRAFT; it is not expected to work yet. When it does, I will move it out of draft. Until then, unless you have a code change to recommend, please hold other comments until it's ready.

@anikifoss

Folks, this PR is in DRAFT; it is not expected to work yet. When it does, I will move it out of draft. Until then, unless you have a code change to recommend, please hold other comments until it's ready.

Sorry, sometimes people want help testing, so I chimed in.

@sammcj (Contributor, Author) commented Jul 29, 2025

No worries at all, and sorry if my message came across as angry (it wasn't intended to); I just want to make sure that folks don't end up wasting their time or bandwidth, or worse, blaming the llama.cpp project when it's my ability that's lacking.

The best help right now would be code corrections (either point out the required change here, or open a PR against my branch, @'ing me so I get a notification).

Thanks for wanting to help out; I promise that as soon as I think it's in a state where it could be tested, I will update this thread.

@anikifoss

No pressure. We're just all excited to try the new model. Thanks for doing the hard work!

@sammcj (Contributor, Author) commented Jul 30, 2025

Bingo! I have conversion, quantisation and llama-server working! 🎉

./bin/llama-server -m /Users/samm/LLM\ Models/zai-org_GLM-4.5-Air/glm-4.5-air-q3_K_M.gguf -n 10 --temp 0.1
build: 4624 (2e54c5125) with Apple clang version 17.0.0 (clang-1700.0.13.5) for arm64-apple-darwin24.5.0
system info: n_threads = 8, n_threads_batch = 8, total_threads = 12

system_info: n_threads = 8 (n_threads_batch = 8) / 12 | Metal : EMBED_LIBRARY = 1 | BF16 = 1 | CPU : NEON = 1 | ARM_FMA = 1 | FP16_VA = 1 | MATMUL_INT8 = 1 | DOTPROD = 1 | ACCELERATE = 1 | REPACK = 1 |

main: binding port with default address family
main: HTTP server is listening, hostname: 127.0.0.1, port: 8080, http threads: 11
main: loading model
srv    load_model: loading model '/Users/samm/LLM Models/zai-org_GLM-4.5-Air/glm-4.5-air-q3_K_M.gguf'
llama_model_load_from_file_impl: using device Metal (Apple M2 Max) - 83999 MiB free
llama_model_loader: loaded meta data with 39 key-value pairs and 803 tensors from /Users/samm/LLM Models/zai-org_GLM-4.5-Air/glm-4.5-air-q3_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = glm4moe
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Zai org_GLM 4.5 Air
llama_model_loader: - kv   3:                         general.size_label str              = 128x9.4B
llama_model_loader: - kv   4:                            general.license str              = mit
llama_model_loader: - kv   5:                               general.tags arr[str,1]       = ["text-generation"]
llama_model_loader: - kv   6:                          general.languages arr[str,2]       = ["en", "zh"]
llama_model_loader: - kv   7:                        glm4moe.block_count u32              = 47
llama_model_loader: - kv   8:                     glm4moe.context_length u32              = 131072
llama_model_loader: - kv   9:                   glm4moe.embedding_length u32              = 4096
llama_model_loader: - kv  10:                glm4moe.feed_forward_length u32              = 10944
llama_model_loader: - kv  11:               glm4moe.attention.head_count u32              = 96
llama_model_loader: - kv  12:            glm4moe.attention.head_count_kv u32              = 8
llama_model_loader: - kv  13:                     glm4moe.rope.freq_base f32              = 1000000.000000
llama_model_loader: - kv  14:   glm4moe.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  15:                  glm4moe.expert_used_count u32              = 8
llama_model_loader: - kv  16:               glm4moe.attention.key_length u32              = 128
llama_model_loader: - kv  17:             glm4moe.attention.value_length u32              = 128
llama_model_loader: - kv  18:               glm4moe.rope.dimension_count u32              = 64
llama_model_loader: - kv  19:                       glm4moe.expert_count u32              = 128
llama_model_loader: - kv  20:         glm4moe.expert_feed_forward_length u32              = 1408
llama_model_loader: - kv  21:                glm4moe.expert_shared_count u32              = 1
llama_model_loader: - kv  22:          glm4moe.leading_dense_block_count u32              = 1
llama_model_loader: - kv  23:                 glm4moe.expert_gating_func u32              = 2
llama_model_loader: - kv  24:               glm4moe.expert_weights_scale f32              = 1.000000
llama_model_loader: - kv  25:                glm4moe.expert_weights_norm bool             = true
llama_model_loader: - kv  26:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  27:                         tokenizer.ggml.pre str              = glm4
llama_model_loader: - kv  28:                      tokenizer.ggml.tokens arr[str,151552]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  29:                  tokenizer.ggml.token_type arr[i32,151552]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  30:                      tokenizer.ggml.merges arr[str,318088]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  31:                tokenizer.ggml.eos_token_id u32              = 151329
llama_model_loader: - kv  32:            tokenizer.ggml.padding_token_id u32              = 151329
llama_model_loader: - kv  33:                tokenizer.ggml.eot_token_id u32              = 151336
llama_model_loader: - kv  34:            tokenizer.ggml.unknown_token_id u32              = 151329
llama_model_loader: - kv  35:                tokenizer.ggml.bos_token_id u32              = 151329
llama_model_loader: - kv  36:                    tokenizer.chat_template str              = [gMASK]<sop>\n{%- if tools -%}\n<|syste...
llama_model_loader: - kv  37:               general.quantization_version u32              = 2
llama_model_loader: - kv  38:                          general.file_type u32              = 12
llama_model_loader: - type  f32:  334 tensors
llama_model_loader: - type q5_0:   90 tensors
llama_model_loader: - type q5_1:    3 tensors
llama_model_loader: - type q3_K:  281 tensors
llama_model_loader: - type q4_K:   92 tensors
llama_model_loader: - type q5_K:    2 tensors
llama_model_loader: - type q6_K:    1 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q3_K - Medium
print_info: file size   = 57.35 GiB (4.46 BPW)
load: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect
load: special tokens cache size = 36
load: token to piece cache size = 0.9713 MB
print_info: arch             = glm4moe
print_info: vocab_only       = 0
print_info: n_ctx_train      = 131072
print_info: n_embd           = 4096
print_info: n_layer          = 47
print_info: n_head           = 96
print_info: n_head_kv        = 8
print_info: n_rot            = 64
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 12
print_info: n_embd_k_gqa     = 1024
print_info: n_embd_v_gqa     = 1024
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-05
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 10944
print_info: n_expert         = 128
print_info: n_expert_used    = 8
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = linear
print_info: freq_base_train  = 1000000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 131072
print_info: rope_finetuned   = unknown
print_info: model type       = 106B.A12B
print_info: model params     = 110.47 B
print_info: general.name     = Zai org_GLM 4.5 Air
print_info: vocab type       = BPE
print_info: n_vocab          = 151552
print_info: n_merges         = 318088
print_info: BOS token        = 151329 '<|endoftext|>'
print_info: EOS token        = 151329 '<|endoftext|>'
print_info: EOT token        = 151336 '<|user|>'
print_info: UNK token        = 151329 '<|endoftext|>'
print_info: PAD token        = 151329 '<|endoftext|>'
print_info: LF token         = 198 'Ċ'
print_info: EOG token        = 151329 '<|endoftext|>'
print_info: EOG token        = 151336 '<|user|>'
print_info: max token length = 1024
load_tensors: loading model tensors, this can take a while... (mmap = true)
model has unused tensor blk.46.eh_proj (size = 134217728 bytes) -- ignoring
model has unused tensor blk.46.embed_tokens (size = 2483027968 bytes) -- ignoring
model has unused tensor blk.46.enorm (size = 16384 bytes) -- ignoring
model has unused tensor blk.46.hnorm (size = 16384 bytes) -- ignoring
model has unused tensor blk.46.shared_head.head (size = 2483027968 bytes) -- ignoring
model has unused tensor blk.46.shared_head.norm (size = 16384 bytes) -- ignoring

load_tensors: offloading 47 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 48/48 layers to GPU
load_tensors: Metal_Mapped model buffer size = 56356.51 MiB
load_tensors:   CPU_Mapped model buffer size =   254.38 MiB
....................................................................................................
llama_context: constructing llama_context
llama_context: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 0
llama_context: kv_unified    = true
llama_context: freq_base     = 1000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M2 Max
ggml_metal_init: picking default device: Apple M2 Max
ggml_metal_load_library: using embedded metal library
ggml_metal_init: GPU name:   Apple M2 Max
ggml_metal_init: GPU family: MTLGPUFamilyApple8  (1008)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)
ggml_metal_init: simdgroup reduction   = true
ggml_metal_init: simdgroup matrix mul. = true
ggml_metal_init: has residency sets    = true
ggml_metal_init: has bfloat            = true
ggml_metal_init: use bfloat            = true
ggml_metal_init: hasUnifiedMemory      = true
ggml_metal_init: recommendedMaxWorkingSetSize  = 88080.38 MB
llama_context:        CPU  output buffer size =     0.58 MiB
llama_kv_cache_unified:      Metal KV buffer size =   752.00 MiB
llama_kv_cache_unified: size =  752.00 MiB (  4096 cells,  47 layers,  1/ 1 seqs), K (f16):  376.00 MiB, V (f16):  376.00 MiB
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility
llama_context:      Metal compute buffer size =   848.00 MiB
llama_context:        CPU compute buffer size =    16.01 MiB
llama_context: graph nodes  = 3358
llama_context: graph splits = 2
common_init_from_params: added <|endoftext|> logit bias = -inf
common_init_from_params: added <|user|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
common_chat_templates_init: failed to parse chat template (defaulting to chatml): Expected comma in tuple at row 47, column 102:
{{ visible_text(m.content) }}
{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}
                                                                                                     ^
{%- elif m.role == 'assistant' -%}

srv          init: initializing slots, n_slots = 1
slot         init: id  0 | task -1 | new slot n_ctx_slot = 4096
main: model loaded
main: chat template, chat_template: {%- for message in messages -%}
  {{- '<|im_start|>' + message.role + '
' + message.content + '<|im_end|>
' -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
  {{- '<|im_start|>assistant
' -}}
{%- endif -%}, example_format: '<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
'
main: server is listening on http://127.0.0.1:8080 - starting the main loop
srv  update_slots: all slots are idle
srv  log_server_r: request: GET / 127.0.0.1 200
srv  log_server_r: request: GET /favicon.ico 127.0.0.1 404
srv  log_server_r: request: GET /props 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 4096, n_keep = 0, n_prompt_tokens = 26
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 26, n_tokens = 26, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 26, n_tokens = 26
slot      release: id  0 | task 0 | stop processing: n_past = 35, truncated = 0
slot print_timing: id  0 | task 0 |
prompt eval time =     478.73 ms /    26 tokens (   18.41 ms per token,    54.31 tokens per second)
       eval time =     330.12 ms /    10 tokens (   33.01 ms per token,    30.29 tokens per second)
      total time =     808.85 ms /    36 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
srv  params_from_: Chat format: Content-only
slot launch_slot_: id  0 | task 11 | n_predict = 128 exceeds server configuration, setting to 10
slot launch_slot_: id  0 | task 11 | processing task
slot update_slots: id  0 | task 11 | new prompt, n_ctx_slot = 4096, n_keep = 0, n_prompt_tokens = 48
slot update_slots: id  0 | task 11 | kv cache rm [6, end)
slot update_slots: id  0 | task 11 | prompt processing progress, n_past = 48, n_tokens = 42, progress = 0.875000
slot update_slots: id  0 | task 11 | prompt done, n_past = 48, n_tokens = 42
slot      release: id  0 | task 11 | stop processing: n_past = 57, truncated = 0
slot print_timing: id  0 | task 11 |
prompt eval time =     789.82 ms /    42 tokens (   18.81 ms per token,    53.18 tokens per second)
       eval time =     332.72 ms /    10 tokens (   33.27 ms per token,    30.06 tokens per second)
      total time =    1122.55 ms /    52 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200
curl -X POST http://localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Tell me a joke about penguins."}
    ],
    "max_tokens": 128
  }'
{"choices":[{"finish_reason":"length","index":0,"message":{"role":"assistant","content":"Why don't penguins fly?\n\nBecause they're"}}],"created":1753833691,"model":"gpt-3.5-turbo","system_fingerprint":"b4624-2e54c5125","object":"chat.completion","usage":{"completion_tokens":10,"prompt_tokens":48,"total_tokens":58},"id":"chatcmpl-t533hXZP8BopY7Hh3jAr5EgPL2ihikk3","timings":{"prompt_n":42,"prompt_ms":789.824,"prompt_per_token_ms":18.805333333333333,"prompt_per_second":53.17640385706183,"predicted_n":10,"predicted_ms":332.722,"predicted_per_token_ms":33.2722,"predicted_per_second":30.05512109208288}}

Please feel free to test this out and if you have code changes to suggest - please do those here.

Note: If you end up sharing any GGUFs built from this PR - PLEASE make it clear that they're built from a llama.cpp PR (aka an unofficial fork) and that there may be changes before it's stable.

@sammcj sammcj marked this pull request as ready for review July 30, 2025 00:30
@jukofyork (Collaborator)

@jukofyork [image]

Thanks!

@AesSedai (Contributor) commented Aug 5, 2025

@CISC I narrowed down the gibberish issue a bit. It requires setting --batch-size 4096 --ubatch-size 4096 and possibly having a long multi-turn chat going. When I removed the batch-size / ubatch-size, my 40k and 50k token chats began working again. Setting the sizes up to 2048 / 2048 also worked. Something about 4096 / 4096 combined with over 32k context across multiple turns leads to that gibberish edge case.

I also tried a needle in a haystack test with a 35k token prompt with a direction to answer a question from the text as a one-shot and that worked. So I don't have a reproducible smoking gun, but batch-size / ubatch-size is involved and for now I'm just scaling them back to make it work.

@CISC (Collaborator) commented Aug 5, 2025

I also tried a needle in a haystack test with a 35k token prompt with a direction to answer a question from the text as a one-shot and that worked. So I don't have a reproducible smoking gun, but batch-size / ubatch-size is involved and for now I'm just scaling them back to make it work.

Ah, ok, so that means it's not a model issue then, that's great!

Submit an issue though. :)

@CISC (Collaborator) commented Aug 5, 2025

Just FYI for anyone wanting to create i-quants: as the final layer will not get imatrix data until MTP is supported, it has to be overridden for lower quants to work, e.g. using --tensor-type 46=iq4_xs or --tensor-type 92=iq4_xs.

cc/ @bartowski1182 @danielhanchen @nicoboss

@jacekpoplawski (Contributor)

I am getting over 45 t/s on three 3090s with the unsloth Q4 quant of GLM Air; here is the optimized command:

llama-server -ts 18/17/18 -ngl 99 -m ~/models/GLM-4.5-Air-UD-Q4_K_XL-00001-of-00002.gguf --n-cpu-moe 2 --jinja --host 0.0.0.0

@jukofyork (Collaborator)

1. It still seems to be skipping warmup. It's loading the model into system RAM **after** receiving the first prompt.

I can confirm it's not warming up.

Manually setting --override-kv glm4moe.expert_used_count=int:160 to try to get it to warm up triggers:

ggml_new_object: not enough space in the context's memory pool (needed 5730848, available 5730480)

If I patch src/llama-context.cpp:

uint32_t llama_context::graph_max_nodes() const {
    //return std::max<uint32_t>(1024u, 8u*model.n_tensors());
    return std::max<uint32_t>(65536u, 8u*model.n_tensors());
} 

and then run with --override-kv glm4moe.expert_used_count=int:160 it warms up fine.

You then need to rerun without --override-kv glm4moe.expert_used_count=int:160.

I've got to go out so no more time to investigate until later.

@jukofyork (Collaborator)

Actually, no, it's still not warming up properly; it's just a lot quicker because it's got the experts mmapped, I think. I'll see if I can figure it out later if nobody else has by then.

@jukofyork (Collaborator)

I've found it:

                // MoE layer with shared experts
                //const int64_t n_expert      = hparams.n_expert;
                //const int64_t n_expert_used = hparams.n_expert_used;

                // Process routed experts using existing MoE infrastructure
                ggml_tensor * routed_out = build_moe_ffn(cur,
                        model.layers[il].ffn_gate_inp,
                        model.layers[il].ffn_up_exps,
                        model.layers[il].ffn_gate_exps,
                        model.layers[il].ffn_down_exps,
                        model.layers[il].ffn_exp_probs_b,
                        n_expert, n_expert_used,
                        LLM_FFN_SILU, hparams.expert_weights_norm,
                        true, hparams.expert_weights_scale,
                        (llama_expert_gating_func_type) hparams.expert_gating_func,
                        il);
                cb(routed_out, "ffn_moe_out", il);

The local n_expert and n_expert_used were shadowing those set here:

llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    arch             (params.arch),
    hparams          (params.hparams),
    cparams          (params.cparams),
    ubatch           (params.ubatch),
    n_embd           (hparams.n_embd),
    n_layer          (hparams.n_layer),
    n_rot            (hparams.n_rot),
    n_ctx            (cparams.n_ctx),
    n_head           (hparams.n_head()),
    n_head_kv        (hparams.n_head_kv()),
    n_embd_head_k    (hparams.n_embd_head_k),
    n_embd_k_gqa     (hparams.n_embd_k_gqa()),
    n_embd_head_v    (hparams.n_embd_head_v),
    n_embd_v_gqa     (hparams.n_embd_v_gqa()),
    n_expert         (hparams.n_expert),
    n_expert_used    (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
    freq_base        (cparams.rope_freq_base),
    freq_scale       (cparams.rope_freq_scale),
    ext_factor       (cparams.yarn_ext_factor),
    attn_factor      (cparams.yarn_attn_factor),
    beta_fast        (cparams.yarn_beta_fast),
    beta_slow        (cparams.yarn_beta_slow),
    norm_eps         (hparams.f_norm_eps),
    norm_rms_eps     (hparams.f_norm_rms_eps),
    n_tokens         (ubatch.n_tokens),
    n_outputs        (params.n_outputs),
    n_ctx_orig       (cparams.n_ctx_orig_yarn),
    pooling_type     (cparams.pooling_type),
    rope_type        (hparams.rope_type),
    sched            (params.sched),
    backend_cpu      (params.backend_cpu),
    cvec             (params.cvec),
    loras            (params.loras),
    mctx             (params.mctx),
    cross            (params.cross),
    cb_func          (params.cb),
    res              (params.res),
    ctx0             (res->get_ctx()),
    gf               (res->get_gf()) {
        res->set_params(params);
    }
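The shadowing pattern can be reduced to a tiny sketch (hypothetical names, not the actual llama.cpp types): the context deliberately bumps `n_expert_used` up to `n_expert` during warmup so every expert's weights get touched, but a local rebinding inside the build function hides that member.

```python
class GraphContext:
    def __init__(self, n_expert=128, n_expert_used=8, warmup=False):
        self.n_expert = n_expert
        # during warmup, route through every expert so all weights get loaded
        self.n_expert_used = n_expert if warmup else n_expert_used

    def build_moe_buggy(self, hparams_n_expert_used=8):
        # local name shadows the member -- the warmup override is lost
        n_expert_used = hparams_n_expert_used
        return n_expert_used

    def build_moe_fixed(self):
        # read the member set in the constructor instead
        return self.n_expert_used

ctx = GraphContext(warmup=True)
print(ctx.build_moe_buggy())  # 8   -- warmup no longer exercises all experts
print(ctx.build_moe_fixed())  # 128 -- warmup touches every expert
```

This is why warmup was slow for the MoE variants: with the shadowed value, only the top-8 experts were ever loaded during the warmup pass.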

@jukofyork
Collaborator

#15088

@createthis
Contributor

createthis commented Aug 5, 2025

@jukofyork confirmed. This fixes warmup for me. It also restores the GLM-4.5 to the performance levels I've come to expect from llama.cpp:

[screenshot, 2025-08-05: performance metrics]

Startup command:

Details
./build/bin/llama-server \
    --model /data/GLM-4.5-GGUF/q4_k_m/GLM-4.5-Q4_K_M.gguf \
    --alias GLM-4.5-GGUF:q4_k_m \
    --no-webui \
    --numa numactl \
    --threads 32 \
    --ctx-size 131072 \
    --n-gpu-layers 94 \
    -ot "blk\.(3|4|5|6|7|8|9|10|11|12|13|14|15|16|17)\.ffn_.*=CUDA0" \
    -ot exps=CPU \
    -ub 4096 -b 4096 \
    --seed 3407 \
    --temp 0.6 \
    --top-p 1.0 \
    --log-colors \
    --flash-attn \
    --host 0.0.0.0 \
    --jinja \
    --port 11434

I had GLM-4.5 write a poem for you:

Jukofyork, with skillful hand,
Commit c81de6e fixed the land.
GLM-4.5 warmup, once so slow,
Now performs with steady glow.
Removed those lines that caused the pain,
Llama.cpp runs fast again.

@jukofyork
Collaborator

No problem and I can confirm it's running as expected for me now too (~6.5 tokens/s generation).

I've managed to transplant the vocab into Qwen2.5-Coder-0.5B-Instruct:

Loading config from 'Qwen2.5-Coder-0.5B-Instruct'... Done.
Loading config from 'GLM-4.5'... Done.
Loading tokenizer from 'Qwen2.5-Coder-0.5B-Instruct'... Done.
Loading tokenizer from 'GLM-4.5'... Done.
Loading model from 'Qwen2.5-Coder-0.5B-Instruct'... Done.

Input model configuration:
- Target vocabulary size    : 151552 (used = 151365, unused = 187)
- Donor vocabulary size     : 151936
- Donor num layers          : 24 (tied embeddings = True)
- Donor hidden size         : 896
- Donor attention heads     : 14
- Donor intermediate size   : 4864 (ratio = 1:5.4)
- Donor total parameters    : 494032768 (0.49B)
-- Embedding parameters     : 136134656 (0.14B)
-- Non-embedding parameters : 357898112 (0.36B)

Processing 3 automatic token overrides:
✘ 'bos_token_id' : Not found for target model
✔ 'eos_token_id' : 151329 '<|endoftext|>' → [151645] '<|im_end|>'
✘ 'pad_token_id' : 151329 is already mapped to [151645]

Processing 14 manual token overrides:
✔ 151329 : '<|endoftext|>' → [151643] '<|endoftext|>'
✔ 151330 : '[MASK]' → [151643] '<|endoftext|>'
✔ 151331 : '[gMASK]' → [151643] '<|endoftext|>'
✔ 151332 : '[sMASK]' → [151643] '<|endoftext|>'
✔ 151333 : '<sop>' → [151643] '<|endoftext|>'
✔ 151334 : '<eop>' → [151643] '<|endoftext|>'
✔ 151335 : '<|system|>' → [151644, 8948] '<|im_start|>system'
✔ 151336 : '<|user|>' → [151644, 872] '<|im_start|>user'
✔ 151337 : '<|assistant|>' → [151644, 77091] '<|im_start|>assistant'
✔ 151338 : '<|observation|>' → [151644, 872] '<|im_start|>user'
✔ 151352 : '<tool_call>' → [151657] '<tool_call>'
✔ 151353 : '</tool_call>' → [151658] '</tool_call>'
✔ 151354 : '<tool_response>' → [151657] '<tool_call>'
✔ 151355 : '</tool_response>' → [151658] '</tool_call>'

NOTE: Using an "untied" copy of 'embed_tokens.weight' as new 'lm_head.weight' tensor...

Transplanting tokens: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 151365/151365 [00:42<00:00, 3558.20token/s]

Transplant mappings:
- 1 to 1  : 123102 (81%)
- 2 to 1  : 23944 (16%)
- 3 to 1  : 3262 (2.2%)
- 4 to 1  : 821 (0.54%)
- 5 to 1  : 181 (0.12%)
- 6 to 1  : 26 (0.017%)
- 7 to 1  : 21 (0.014%)
- 8 to 1  : 5 (0.0033%)
- 9 to 1  : 1 (0.00066%)
- 13 to 1 : 1 (0.00066%)
- 16 to 1 : 1 (0.00066%)

Head initialized with:
- Copies : 123102 (81%)
- Means  : 28263 (19%)
- Zeros  : 187 (0.12%)

Output model configuration:
- Output vocabulary size    : 151552
- Output num layers         : 24 (tied embeddings = False)
- Output hidden size        : 896
- Output attention heads    : 14
- Output intermediate size  : 4864 (ratio = 1:5.4)
- Output total parameters   : 629479296 (0.63B)
-- Embedding parameters     : 271581184 (0.27B)
-- Non-embedding parameters : 357898112 (0.36B)

Saving model and tokenizer to 'GLM-4.5-DRAFT-0.6B-UNTRAINED' folder

So, assuming it trains OK, we should have a draft model in a day or so.

It actually looks to have transplanted very well, as even the untrained draft is getting a high acceptance rate for refactoring tasks:

prompt eval time =   59625.37 ms /  2339 tokens (   25.49 ms per token,    39.23 tokens per second)
       eval time =  288397.17 ms /  3170 tokens (   90.98 ms per token,    10.99 tokens per second)
      total time =  348022.54 ms /  5509 tokens
slot print_timing: id  0 | task 0 | 
draft acceptance rate = 0.74499 ( 2080 accepted /  2792 generated)
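The transplant-mapping statistics above follow from a simple rule, which can be sketched like this (hypothetical helper names; the real transplant tool's logic is more involved): each target token is re-tokenized with the donor tokenizer; a 1-token match copies the donor embedding row, an n-token match takes the mean of the rows, and unused slots are zeroed.

```python
def transplant_row(target_token, donor_encode, donor_embed, dim):
    piece_ids = donor_encode(target_token)
    if not piece_ids:                       # unused/unmappable slot
        return [0.0] * dim, "zero"
    if len(piece_ids) == 1:                 # 1-to-1 mapping: straight copy
        return list(donor_embed[piece_ids[0]]), "copy"
    rows = [donor_embed[i] for i in piece_ids]
    mean = [sum(col) / len(rows) for col in zip(*rows)]
    return mean, "mean"                     # n-to-1 mapping: mean of pieces

# toy donor with a 2-dim embedding table
donor_embed = {0: [1.0, 0.0], 1: [0.0, 1.0]}
encode = {"hello": [0], "hello world": [0, 1], "": []}.get
print(transplant_row("hello", encode, donor_embed, 2))        # copied row
print(transplant_row("hello world", encode, donor_embed, 2))  # mean of rows
```

In the log above, the "Copies : 123102 (81%)" line corresponds to the 1-to-1 mappings and "Means : 28263 (19%)" to everything n-to-1.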

@Mushoz

Mushoz commented Aug 6, 2025

Yesterday a bug was found for these models in vLLM and it was patched out. The PR in question is this one: vllm-project/vllm#22203

Does anyone know if this implementation is using float32 data for the self.gate module? Because if not, it might need a similar fix.

@jacekpoplawski
Contributor

> Yesterday a bug was found for these models in vLLM and it was patched out. The PR in question is this one: vllm-project/vllm#22203
>
> Does anyone know if this implementation is using float32 data for the self.gate module? Because if not, it might need a similar fix.

isn't this related?

if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
    // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
    ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}

@CISC
Collaborator

CISC commented Aug 6, 2025

> isn't this related?

That's just for ffn_down though, this suggests it should be done for ffn_gate too (if it's enough to just up the precision of mul-mat), can someone test?

@jacekpoplawski
Contributor

jacekpoplawski commented Aug 6, 2025

> isn't this related?
>
> That's just for ffn_down though, this suggests it should be done for ffn_gate too (if it's enough to just up the precision of mul-mat), can someone test?

I am not sure how to see the difference. Should the perplexity change? I tried the following fix, but the perplexity stays the same:

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 053c72d6..4c101848 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -662,11 +662,19 @@ ggml_tensor * llm_graph_context::build_ffn(
             case LLM_FFN_SEQ:
                 {
                     cur = build_lora_mm(gate, tmp);
+                    if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+                        // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+                        ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+                    }
                     cb(cur, "ffn_gate", il);
                 } break;
             case LLM_FFN_PAR:
                 {
                     cur = build_lora_mm(gate, cur);
+                    if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+                        // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+                        ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+                    }
                     cb(cur, "ffn_gate", il);
                 } break;
         }
@@ -746,6 +754,10 @@ ggml_tensor * llm_graph_context::build_ffn(

     if (gate && type_gate == LLM_FFN_PAR) {
         cur = ggml_mul(ctx0, cur, tmp);
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
         cb(cur, "ffn_gate_par", il);
     }

@jukofyork
Collaborator

> isn't this related?
>
> That's just for ffn_down though, this suggests it should be done for ffn_gate too (if it's enough to just up the precision of mul-mat), can someone test?

Just checked and it's the [hidden_dim, n_experts] router logits tensor:

-        router_logits, _ = self.gate(hidden_states)
+        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))

which I think is always kept as float32 for all MoE models in llama.cpp anyway, so should be fine.
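For context, the routing step under discussion can be sketched in a few lines (pure Python with toy sizes; `route` and its normalization are illustrative, not llama.cpp's `build_moe_ffn` API): GLM-4.5 scores its 128 routed experts with a sigmoid gate, keeps the top 8, and the gate matmul producing the logits is what stays in float32.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def route(logits, top_k=8):
    # logits: one float per expert, assumed computed in float32
    scores = [sigmoid(v) for v in logits]
    # pick the top_k experts by score
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    selected = order[:top_k]
    # normalize the selected scores so the mixture weights sum to 1
    # (GLM-4.5 sets expert_weights_norm, per the build_moe_ffn call above)
    total = sum(scores[i] for i in selected)
    return {i: scores[i] / total for i in selected}

weights = route([0.1 * i for i in range(16)], top_k=8)
print(sorted(weights))                  # the 8 highest-scoring experts
print(round(sum(weights.values()), 6))  # 1.0
```

Note the sigmoid here rather than the usual softmax; that is the `expert_gating_func` distinction the architecture registration handles.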

@jukofyork
Collaborator

                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
                # Some tensor types are always in float32
                if data_qtype is False and (
                    any(
                        self.match_model_tensor_name(new_name, key, bid)
                        for key in (
                            gguf.MODEL_TENSOR.FFN_GATE_INP,
                            gguf.MODEL_TENSOR.POS_EMBD,
                            gguf.MODEL_TENSOR.TOKEN_TYPES,
                            gguf.MODEL_TENSOR.SSM_CONV1D,
                            gguf.MODEL_TENSOR.SHORTCONV_CONV,
                            gguf.MODEL_TENSOR.TIME_MIX_FIRST,
                            gguf.MODEL_TENSOR.TIME_MIX_W1,
                            gguf.MODEL_TENSOR.TIME_MIX_W2,
                            gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
                            gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
                            gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
                            gguf.MODEL_TENSOR.POSNET_NORM1,
                            gguf.MODEL_TENSOR.POSNET_NORM2,
                            gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                            gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
                            gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
                            gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                        )
                    )
                    or not new_name.endswith(".weight")
                ):
                    data_qtype = gguf.GGMLQuantizationType.F32
and the corresponding check in llama_model_quantize_internal:

        // do not quantize expert gating tensors
        // NOTE: can't use LLM_TN here because the layer number is not known
        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;

then IIRC, in the backends any time a float32 is the left tensor in a matrix product, the other side gets promoted to float32 too.
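As a self-contained illustration of why half-precision accumulators misbehave in the first place (standalone Python, not llama.cpp code; the struct module's `'e'` format stands in for an fp16 register), accumulating many small contributions in fp16 stalls once the running sum grows, because each increment falls below half an fp16 ulp:

```python
import struct

def to_half(x):
    # round a Python float to the nearest IEEE 754 half-precision value
    return struct.unpack('e', struct.pack('e', x))[0]

def accumulate_half(values):
    acc = 0.0
    for v in values:
        acc = to_half(acc + v)  # accumulator rounded to fp16 every step
    return acc

values = [0.001] * 4096          # exact sum is 4.096
print(accumulate_half(values))   # stalls around 4.0, short of 4.096
print(sum(values))               # float64 reference stays close to 4.096
```

Promoting the accumulator to float32 (which is what `ggml_mul_mat_set_prec(cur, GGML_PREC_F32)` requests for these tensors) removes this failure mode.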

@CISC
Collaborator

CISC commented Aug 6, 2025

@jukofyork Thanks for checking, then all is good.

@Thireus
Contributor

Thireus commented Aug 7, 2025

Many thanks to @sammcj, @CISC, and everyone who contributed! The code has been successfully ported and merged into ik_llama.

@jukofyork
Collaborator

https://huggingface.co/jukofyork/GLM-4.5-DRAFT-0.6B-v3.0
https://huggingface.co/jukofyork/GLM-4.5-DRAFT-0.6B-v3.0-GGUF

This should hopefully also work on GLM-4.5-Air and GLM-4-32B-0414 as they appear to use the same tokeniser.

@CISC
Collaborator

CISC commented Aug 8, 2025

@jukofyork I haven't tested with this model, but shouldn't this work without a specially crafted model since universal assisted decoding was merged?

Edit: Requires a fair amount of --spec-replace though.

@jukofyork
Collaborator

> @jukofyork I haven't tested with this model, but shouldn't this work without a specially crafted model since universal assisted decoding was merged?

Yeah, I tested the new universal assisted decoding with qwen2.5:0.5b and qwen3:0.6b, and it works quite well for this model (there is ~80% single-token overlap), but the custom fine-tuned model still gets around a 10% higher acceptance rate (at least for my coding tests).

> Edit: Requires a fair amount of --spec-replace though.

I didn't play with this setting though, so agree this might improve things.
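The acceptance rates discussed in this thread translate into throughput via the standard speculative-decoding estimate (a back-of-envelope approximation, not llama.cpp's actual scheduler): with per-token acceptance probability p and k drafted tokens per step, the target model emits on average (1 - p**(k+1)) / (1 - p) tokens per forward pass.

```python
def expected_tokens_per_step(p, k):
    # geometric-series estimate: tokens accepted per target forward pass,
    # counting the bonus token sampled when the whole draft is accepted
    return (1 - p ** (k + 1)) / (1 - p)

p = 2080 / 2792                # acceptance rate from the log earlier: 0.74499
print(round(p, 5))
print(round(expected_tokens_per_step(p, 8), 2))  # roughly 3-4 tokens/pass
```

So even the untrained transplanted draft already buys a several-fold reduction in target-model forward passes, which is consistent with the timings reported above.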

@ajunca

ajunca commented Aug 12, 2025

> @CISC I narrowed down the gibberish issue a bit. It requires setting --batch-size 4096 --ubatch-size 4096 and possibly having a long multi-turn chat going. When I removed the batch-size / ubatch-size, my 40k and 50k token chats began working again. Setting the sizes up to 2048 / 2048 also worked. Something about 4096 / 4096 combined with over 32k context across multiple turns leads to that gibberish edge case.
>
> I also tried a needle-in-a-haystack test with a 35k token prompt with a direction to answer a question from the text as a one-shot, and that worked. So I don't have a reproducible smoking gun, but batch-size / ubatch-size is involved, and for now I'm just scaling them back to make it work.

I think I found where the bug behind this problem resides. Please take a look at my post on #15112. It's very likely an invalid CUDA graph update: if you compile with GGML_CUDA_USE_GRAPHS=OFF, the gibberish issue (GGGGGGGG...) no longer appears.

@ElrondL

ElrondL commented Oct 8, 2025

Hi, LLM noob here. How do you actually run the GLM models in llama.cpp (what's the command)? I tried looking for the gguf file in ggml-org/ but can't find it

@R-Dson
Contributor

R-Dson commented Oct 8, 2025

Hi, LLM noob here. How do you actually run the GLM models in llama.cpp (what's the command)? I tried looking for the gguf file in ggml-org/ but can't find it

You can use any repo or file in GGUF format from Hugging Face; it does not have to be from ggml-org/. Just search for "GLM-4.5-GGUF" (or 4.6) and sort by most popular to see which files people are using.
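As a minimal sketch of a command (assuming you have built llama.cpp; the repo name below is a placeholder, not an official ggml-org upload, so substitute whichever GGUF repo you chose):

```shell
# llama-server can pull a GGUF straight from Hugging Face with -hf,
# which takes <user>/<repo>[:<quant>]
./build/bin/llama-server \
    -hf unsloth/GLM-4.5-Air-GGUF:Q4_K_M \
    --jinja \
    --port 8080
```

The full set of startup flags (context size, offload regexes, etc.) used by others in this thread is shown in the earlier startup-command details block.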


Labels

model (Model specific), python (python script changes)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature Request: GLM 4.5 MoE support