Conversation
This is correct - we always alternate between conventional and speculative passes. It's definitely not optimal, but it improves flexibility for regular sampling: it allows changing the speculative parameters and even disabling speculation per request, while keeping the logic quite simple. It should be possible to improve this by keeping track of which slots are speculating on each iteration and skipping those slots when adding tokens to the conventional batch. It might be a good idea to implement this separately to avoid huge changes to the logic in a single PR. |
|
Generally we should try to minimize the changes. On first look, I think the path that involves minimal changes is:
Extracting the MTP logits during
Currently, I am not sure which way is better. The first requires a new API call, while the second might break some existing assumptions (not sure if that's the case yet). In any case, you can avoid this until you get the implementation working with a reasonable speedup. After that, we can discuss further how to best refactor the implementation. |
I don't see an issue with adding a new API for this, and it would be easier to use. |
|
Out of curiosity, is the API for this expected to be flexible enough that we could jump off of it to add things like Medusa / Eagle style (or IBM Accelerator) self-speculative decoding heads? I'm pretty sure they work fairly similarly (depending on the final output embeddings of the current token).

Another note: after some consideration, I think the expected speedup of the MTP module will depend a lot on the hardware the model's running on, particularly because it's an MoE model. While the next-token prediction depends only on the current state, if we're doing self-speculative decoding, that's additional forward passes. Those forward passes aren't guaranteed to have the same expert usage patterns, meaning the speedup should be some function of the tokens predicted and the expert re-use coefficient for the tokens verified.

So, just noting that if it's implemented and there's not a 2x or 3x increase in T/s, it may not be a skill issue on the part of a contributor, but due to the mathematical nature of the calculation. For people running franken setups with attention / KV cache on GPU and MoE FFNs on CPU, it's possible that using previously unused experts in the verification sweep may result in a weird situation where the parallel verification process is actually memory-bandwidth bound.

Not to discourage the implementation of this, I just wanted to give a heads up so nobody's dejected if the theoretical speedups can't be hit. There should still be at least some speedup, though. |
|
Thanks all for the suggestions. Will definitely look to refactor into something nicer once correctness can be established.
Yeah, I'd generally recommend that people temper their expectations with this. Especially given that these three models only have one MTP head, the theoretical performance gain is hard-bounded by 2x on the top end, and that's assuming a perfectly efficient implementation and 100% draft acceptance. In the absence of actual data from a working prototype... I'd probably guess that the implementation after this PR will be on the order of a 40% speedup, then up to 80% after completing this:
Optimistically, I hope to have an ugly but working prototype done sometime today. |
|
I've gotten to the point where I can get the MTP head to output stuff, but managing the KV cache with an external call to a separate MTP graph adds an unbelievable amount of complexity: I think we need to do a forward pass for the MTP layer not just when we're sampling, but for every decode token we run. This goes against the scheduling/batching that we're doing (we'd probably have to add some form of per-token callback). I think I'll take the principled approach suggested by @ggerganov above and just create a single augmented graph. On the plus side, from this previous attempt I'm pretty confident the MTP subgraph itself is correct, so it wasn't a total waste of time. 🤪 I'll commit the old branch in a sec in case it ever winds up being useful, but I kind of doubt it (outside of as a reference for constructing the MTP subgraph) |
|
On second thought, building a single augmented graph also doesn't work, because we need the main model's sampled token in the MTP subgraph. We could make some shortcut assumptions, like greedy-sampling in the MTP subgraph, but as soon as we fail to match the actual main-model sampled token for the first time, the MTP layer's KV cache is invalid. Something along the lines of the original approach might work; management of the MTP subgraph's KV cache could be made easier by using cparams.embeddings = true and LLAMA_POOLING_TYPE_NONE, decoding an entire batch, then running the entire batch through the MTP head (discarding outputs) to keep its cache up to date. |
|
This commit sort of works, in the sense that it outputs tokens but
Still enough to make me optimistic that this general approach can be refined into a real implementation and avoids the challenges noted in my last two comments. |
|
Okay, I believe this commit "works" in that both the main model and MTP outputs seem correct under my informal test conditions. The model is now about as coherent as the base model, at least in basic conversations and solving AIME problems. The typical draft acceptance rate is ~70% with my samplers; it usually gets higher than that for simple responses, code, and math, and lower for creative writing. It would probably be good to compare against the vLLM implementation to see if that's around expected (I don't have enough VRAM to run either variant in vLLM, sadly). This implementation is still far from done:
but despite all of that, I think it works as a proof of concept. Even with the implementation as rough as it is right now, it's giving roughly a 20-25% performance uplift |
|
Tried to run it in an RP scenario (using a Q4 quant) and got a 0.07 to 0.11 acceptance rate on swipes (one time unexpectedly got 0.18) (temp = 0.8, min-p = 0.05, top-p = 0.95). So yeah, we probably shouldn't expect too much from it. |
This sounds like I have a bug with token caching and/or KV shifting on my end; so far I haven't tested in any setting outside of the server webui itself. My caching behavior is almost certainly wrong outside of optimal scenarios right now. Things are still very WIP. I'll test later, but I'd probably expect 50-60% draft acceptance for RP once everything is working, for reference. |
|
Upon a bit of testing on my end in RP/creative-writing scenarios, I can't find any obvious correctness issues with this prototype's cache management; I think the draft head is just unusually bad in an RP setting. I did replace the odd logit-hack-then-standard-sample workflow with a greedy sample, which I think is simpler, less hacky, and provides a slight boost to acceptance rates. But I'm still not getting much more than ~40% draft acceptance in RP scenarios with heavy use of advanced samplers like DRY and XTC. I think it'd be good to know if the vLLM implementation yields similar results. Aside from that, I'll start refactoring. The last correctness issue, which I think may also be the thorniest, is that the MTP cache updates need the embeddings of the main decode pass in order to run, but we can also only store one ubatch's worth of embeddings at once, so the KV cache update step needs to be moved into |
|
Not an expert by any means, but XTC is unlikely to work well with MTP as it's excluding the top choice(s). Have you tried it without XTC? DRY shouldn't be much of an issue, but there possibly is some impact there as well. |
|
Is work on this still progressing in the background? If not, then what kind of work still remains to be done? Is it mainly cleanup and refactoring? If so, could another developer try their hands at this? I am eagerly awaiting MTP support and it would be a shame to see no progress on this after an initial working implementation has been done. I have no experience with llama.cpp in particular, but if it's "only" cleanup work that needs no deep knowledge of ML libraries, I might be able to support. |
@SamuelOliveirads is still working on this in a PR on this branch, see F1LM1#3. It's primarily a refactor and some optimizations, but to make my crappy prototype work in a way that is reasonably maintainable/extensible is not that trivial. I've been pretty busy the last few weeks but I don't think we're too far off from a reasonable implementation |
|
That is great to hear! Thank you a lot for your effort so far! I had already guessed that it might not be possible to contribute without deeper knowledge, but was willing to give it a try anyway. It's great to hear that @SamuelOliveirads is already working on it! |
|
The latest commits successfully integrate the MTP. The next major step is optimization. Please note: as of now, this branch is expected to be slower than the baseline without MTP. It should be considered a development preview, not a performance-ready feature. To tackle the performance issues, a lot of work is needed, particularly in areas where I'm still building my expertise. I've gathered my detailed findings and ideas for optimization in a separate discussion here: F1LM1#4 This is an open invitation for anyone with experience in these areas to take a look and share insights. Any help would be greatly appreciated!
|
|
I haven't made much progress on the optimization front over the last week, aside from the small graph reuse improvement mentioned in PR #4. Speaking of which, @ggerganov, I would love to get your thoughts on the strategic goals for this PR. As it stands, the MTP implementation is functionally correct and could be tested, but it's not expected to be faster than the baseline. Considering the challenges in optimizing it, do you think it's a good idea to merge this version now? It could serve as a foundation, allowing other maintainers to contribute in other areas, like further optimizations or adapting the MTP logic to other models. If that's the best approach, I can focus on adding the necessary user parameters and preparing the PR for wider testing and review. However, if you feel that improving the performance is essential before merging, I would greatly appreciate any help or information you could share regarding the current bottlenecks I've identified. |
|
I've been following this PR for quite a while, and thank you for the enormous work you have done! I believe I saw somewhere that the MTP acceptance rate is very sensitive to the quantization used. It seems standard practice is to keep the MTP layers at FP8 or even FP16 (link). But a quick look at some available quants shows they usually quantize the MTP layers at the same level as the main weights. So I guess improving this may lead to better acceptance rates when using quants. |
That's a very plausible point, as vision models often suffer from the same issue, which is why most of them are kept at FP16. As for MTP, it's complicated to assess in llama.cpp since we don't fully support it yet, and I'm not aware whether other backends have already tested and documented how quantization affects MTP performance. I still want to thank you for the link. Looking at NVIDIA's TensorRT-LLM backend provided some inspiration on how to solve certain optimization problems. However, it's not a simple task. The ideal scenario would be to perform all MTP operations in a single model pass:
This is complicated because the MTP has a dependency on the previous hidden state. To fix this, we could apply a loop over the MTP graph to get the previous hidden state and generate one token at a time. This, in turn, creates another problem: if I try to generate 4 tokens, llama.cpp expects me to run the graph only once with a batch of 4 tokens, not 4 times with one token each. I'm currently studying the architecture and gathering some ideas, running tests to see if I can find a simpler approach. |
|
Please rebase (looks like it will be a bit of work, ignore if you already are working on it). :) |
commit 912ed2cd9339d1b2875d98744ca5b51fa62e581e
Author: samuel <samueloliveira32df@gmail.com>
Date: Sun Dec 7 23:00:29 2025 -0300
speculative (feat): implement recursive MTP drafting for GLM-4.5
commit bdf72d9
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Dec 6 16:10:16 2025 -0300
sampling (feat): optimize speculative drafting with fast-path selection
commit a91980a
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Dec 6 15:18:19 2025 -0300
mtp (chore): clean old code
commit 6de0ecf
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Dec 6 14:40:13 2025 -0300
mtp (feat): add mtp arg
commit ea77394
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Dec 6 13:47:54 2025 -0300
mtp-graph (fix): move llama_get_logits_ith outside the loop
commit 15dff20
Merge: 171346c cae85fe
Author: samuel <samueloliveira32df@gmail.com>
Date: Thu Oct 16 13:44:41 2025 -0300
Merge branch 'glm4-mtp-batch' of https://github.com/SamuelOliveirads/llama.cpp into glm4-mtp-graph-cache
commit cae85fe
Author: samuel <samueloliveira32df@gmail.com>
Date: Thu Oct 16 13:42:31 2025 -0300
mtp-batch(fix): avoid logits for mtp kv cache operations
commit 171346c
Author: samuel <samueloliveira32df@gmail.com>
Date: Sun Oct 12 16:33:01 2025 -0300
mtp-graph(feat): Reactivate graph reuse only for main model path
commit 0127c6b
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Oct 11 22:20:54 2025 -0300
mtp-batch(chore): Remove final MTP debug logs and dead code
commit 4bcc9e2
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Oct 11 18:51:22 2025 -0300
mtp-batch(fix): Correctly advance cache head and add MTP documentation
commit b4cbe03
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Oct 11 18:37:40 2025 -0300
mtp-batch(chore): Fix logit flags for speculative sampling and remove debug logs
commit a99709d
Author: samuel <samueloliveira32df@gmail.com>
Date: Fri Oct 10 17:24:34 2025 -0300
mtp-batch(refactor): Extract decode context and MTP input logic into helper methods
commit 913af8f
Author: samuel <samueloliveira32df@gmail.com>
Date: Fri Oct 10 16:44:28 2025 -0300
mtp-batch(refactor): Replace MTP boolean flags with an explicit operation enum
commit 6f74ba3
Author: samuel <samueloliveira32df@gmail.com>
Date: Thu Oct 9 22:27:18 2025 -0300
mtp-batch (fix): prevent mtp draft from polluting the cache
commit 5e1d719
Author: samuel <samueloliveira32df@gmail.com>
Date: Thu Oct 9 15:21:23 2025 -0300
mtp-batch (feat): Create and manage sinfo for MTP
commit febd823
Author: samuel <samueloliveira32df@gmail.com>
Date: Sun Oct 5 14:43:40 2025 -0300
mtp-batch (wip): fix how to warmup kv cache for MTP
commit 67c6c06
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Sep 27 19:42:32 2025 -0300
mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer corruption
commit 75dc25e
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Sep 27 17:17:00 2025 -0300
mtp-batch (wip): organize batch for mtp cache
commit 3da7e7f
Author: samuel <samueloliveira32df@gmail.com>
Date: Tue Sep 23 22:45:11 2025 -0300
mtp-batch (fix): warm mtp cache for small batch size
commit df64508
Author: samuel <samueloliveira32df@gmail.com>
Date: Sun Sep 21 21:55:41 2025 -0300
mtp-batch (wip): merge glm graphs
commit 042eb8a
Author: samuel <samueloliveira32df@gmail.com>
Date: Sun Sep 21 21:29:00 2025 -0300
mtp-batch (wip): merge mtp and model graph
commit 1318b2d
Author: samuel <samueloliveira32df@gmail.com>
Date: Sun Sep 14 10:22:59 2025 -0300
mtp-batch (wip): move mtp execution to batch format
commit c6237c7
Merge: 9fab53e 8742ce0
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Sat Sep 13 02:57:01 2025 -0400
Merge pull request #1 from SamuelOliveirads/glm4-moe-mtp
feat: implemented sampling for MTP
commit 8742ce0
Author: samuel <samueloliveira32df@gmail.com>
Date: Sat Sep 6 00:21:18 2025 -0300
feat: apply logits + greedy sampler
commit 5a5bce8
Author: samuel <samueloliveira32df@gmail.com>
Date: Wed Sep 3 17:56:14 2025 -0300
fix: add sample acceptance
commit 07670a2
Author: samuel <samueloliveira32df@gmail.com>
Date: Wed Sep 3 13:25:21 2025 -0300
feat: implemented sampling for MTP
commit 9fab53e
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Tue Sep 2 17:14:09 2025 -0400
fixed mtp kv cache update step in cases where prompt size > n_batch and n_ubatch
commit 98bc0c6
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Tue Aug 26 01:26:51 2025 -0400
replace standard sampler with greedy sampler for mtp draft
commit 471e026
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Tue Aug 19 23:10:56 2025 -0400
fixed vram leak
commit d72f9d5
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Tue Aug 19 01:50:34 2025 -0400
kludge-y kv cache management of mtp layer
commit 382135a
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Sun Aug 17 21:54:45 2025 -0400
fixed mtp kv cache update sequencing after prompt processing
commit 6870f97
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Sun Aug 17 04:59:36 2025 -0400
added proper KV cache management for MTP layers and slightly refactored
commit 6e9bafc
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Fri Aug 15 23:13:56 2025 -0400
failed attempt to implement MTP; outputs tokens but KV cache management is unreasonable
commit cf0f7c0
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Wed Aug 13 02:21:17 2025 -0400
broad thrust of the mtp implementation
commit 03231da
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Tue Aug 12 01:03:59 2025 -0400
add model member function to build mtp graph, to be called from speculative.cpp
commit 1f477b3
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Mon Aug 11 20:54:45 2025 -0400
make nextn weights loadable without a crash
commit e434f87
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Mon Aug 11 01:21:47 2025 -0400
some work towards building mtp layer graph
commit db60623
Author: Aaron Lee <lee.aaron.65@gmail.com>
Date: Sun Aug 10 23:52:54 2025 -0400
added getter for nextn layer count and server slot has_mtp property
GLM-4.6 models exclude specific MTP tensors (`embed_tokens` and `shared_head_head`), implying weight tying with the main model. Previously, this caused a crash when building the graph. This commit adds a fallback mechanism to use the main model's token embeddings and output head when the MTP-specific tensors are missing.
Adds a new `mtp` boolean to `llama_model_params`. When set to false (default):
1. The loader skips loading MTP-specific tensors (NextN layers) using `TENSOR_SKIP`.
2. The KV cache size calculation excludes the MTP layer (`n_layer_kv_from_start`).

This reduces VRAM usage and load time for users running GLM-4.5/4.6 in standard generation mode.
Removes heavy penalty checks (repetition, frequency, presence, DRY) from `common_sampler_sample_speculative`. The specialized speculative sampler now uses a pure ArgMax (Greedy) approach. This significantly reduces CPU overhead during the drafting phase, which improves overall tokens per second.
|
@CISC It should be ready now. |
|
@SamuelOliveirads Hi. Just trying to test the branch, and I got these warnings when building for HIP. Maybe they will help to polish the PR. |
|
Hi there. Could anyone provide a complete example command for llama-server? |
Thanks, I will take some time to look into these warnings.
@CHNtentes Sure! Please keep in mind that this PR is about implementing the MTP architecture and is not necessarily fully optimized yet. I have seen some users getting the same or even lower performance than the baseline. You will need to use three arguments when loading the model:
Fine-tuning Advice
Results will also vary depending on your use case; tasks like coding (where you can use greedy decoding) will likely give better results than creative writing. |
Thanks for your reply. I'll run some tests once the GLM 4.7 GGUFs are downloaded. Is this a totally different implementation? |
The architecture is the same; it uses Eagle, as that is what GLM requires. For comparison, it's something like:
Want to replicate the recommended params from Z.ai? Just use: |
|
I tried this PR with https://huggingface.co/unsloth/GLM-4.7-GGUF/tree/main/UD-Q4_K_XL and the following command as a baseline:

baseline decode speed:
-mtp --draft 1:
-mtp --draft 2:
-mtp --draft 3:
-mtp --draft 2 --draft-p-min 0.85:

It seems the feature is working, but the speed is indeed slower than without MTP. I'll wait for future optimizations. |
|
After some tests I found it crashes for some models, like Xiaomi MiMo. |
Someone can correct me if I'm mistaken, but this is an implementation of MTP for GLM; MiMo will presumably require an extension/modification of this PR for its MTP implementation to work. |
You are right, but the application should not crash; it should ignore unsupported models. |
CISC left a comment
Let's start small and hopefully get the ball rolling...
| add_opt(common_arg( | ||
| {"-mtp", "--multi-token-prediction"}, | ||
| string_format("Activate multi-token-prediction (if supported) (default: %s)", params.mtp ? "true" : "false"), | ||
| [](common_params & params) { | ||
| params.mtp = true; | ||
| } | ||
| )); |
| add_opt(common_arg( | |
| {"-mtp", "--multi-token-prediction"}, | |
| string_format("Activate multi-token-prediction (if supported) (default: %s)", params.mtp ? "true" : "false"), | |
| [](common_params & params) { | |
| params.mtp = true; | |
| } | |
| )); | |
| add_opt(common_arg( | |
| {"-mtp", "--multi-token-prediction"}, | |
| {"-no-mtp", "--no-multi-token-prediction"}, | |
| string_format("whether to use multi-token-prediction (if supported) (default: %s)", params.mtp ? "true" : "false"), | |
| [](common_params & params, bool value) { | |
| params.mtp = value; | |
| } | |
| )); |
src/llama-context.cpp
Outdated
| const llama_memory_context_i * mctx, | ||
| llm_graph_type gtype) const { | ||
| const llama_memory_context_i * mctx, | ||
| llm_graph_type gtype, | ||
| const llama_mtp_params & mtp_params) const { |
Saw this a couple of times now; please don't change these, and align properly. The idea is that you should be able to quickly and easily get an overview of the variable names thanks to the vertical alignment.
src/llama-context.h
Outdated
|
|
| mutable int32_t n_reused = 0; // number of times the previous graph was reused | ||
| }; | ||
| }; No newline at end of file |
Beware of missing EOF newlines; this and others will fail CI.
| if (!res.empty()) { | ||
| std::string idxs_str; | ||
| for (const auto& vec : res.idxs) { | ||
| if (!vec.empty()) { | ||
| if (vec.size() > 8) { | ||
| idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]"; | ||
| } else { | ||
| idxs_str += " ["; | ||
| for(size_t i = 0; i < vec.size(); ++i) { | ||
| idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? "" : ", "); | ||
| } | ||
| idxs_str += "]"; | ||
| } | ||
| } | ||
| } | ||
| } |
Leftover debug, but no logging anymore?
include/llama.h
Outdated
| * @brief Removes KV cache metadata for a specified sequence and token range. | ||
| * This makes the physical cells logically available again without deleting the tensor data. | ||
| */ | ||
| LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); |
I genuinely don't think this array of APIs will be accepted by other maintainers. IMO it breaks a lot of the patterns that we explicitly established in CONTRIBUTING.md (have you even read it before pushing this PR?)
- It doesn't make sense to make a breaking change to `llama_batch` by adding `mtp_params` to it. `llama_set_causal_attn()` and `llama_set_embeddings()` are already used for a similar purpose.
- Why `llama_kv_cache_seq_rm`? What's wrong with `llama_memory_seq_rm()`?
- What is an `sinfo`? Nowhere in this file explains it. It doesn't even have a public struct.
- `llama_set_draft_input_hidden_state` indicates that we have to manually copy embeddings from the main LLM to the MTP layers. This doesn't resolve the core issue brought up by Georgi's comment. Plus, it breaks the API naming convention of `llama_<module>_<verb>`.
Just a reminder that supporting MTP is NOT hard, but designing an API to support all MTP models is hard.
Unless this PR invests more thought and work into designing a "universal" API that supports most MTP models, I don't think we can consider merging it.
|
@F1LM1, do you want to work on these fixes, or should I handle them? |
Sure, I'll have time to look at this later this week |
This is very much a draft/proof of concept I'm playing with, just one idea for an MTP implementation. Planning to test on GLM-4.5 because it's the only model out there that we've preserved NextN tensors for.
From what I can tell
So implementation-wise it seems like
- `mtp_speculative_gen_draft` in speculative.cpp that is vastly simplified, and branch into it in server.cpp when a slot has MTP (versus `common_speculative_gen_draft`).
- `ctx_dft` in this case as well. It's a bit hacky, but I was thinking we could just have `ctx_dft = ctx` and then have both normal and MTP passes write over the shared `ctx` logits. I think this minimizes required code changes elsewhere.

This is my first time (1) working with ML stuff outside of Python, (2) attempting to contribute, so patience is appreciated :)