
Commit 1ec9145

cont : alternative impl
1 parent 18e1fd2 commit 1ec9145

6 files changed: 79 additions & 101 deletions


src/llama-memory-hybrid-iswa.cpp

Lines changed: 3 additions & 4 deletions
@@ -3,7 +3,6 @@
 #include "llama-impl.h"
 #include "llama-model.h"
 #include "llama-context.h"
-#include <limits>
 
 //
 // llama_memory_hybrid_iswa
@@ -137,10 +136,10 @@ void llama_memory_hybrid_iswa::clear(bool data) {
 }
 
 bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // try removing from the recurrent cache first since it may fail;
+    // if it does, the cache will not have been mutated
     if (!mem_recr->seq_rm(seq_id, p0, p1)) {
-        mem_recr->seq_rm(seq_id, 0, std::numeric_limits<llama_pos>::max());
-        mem_attn->seq_rm(seq_id, p0, p1);
-        return false; // this should always fail, since we cannot truncate recurrent
+        return false;
     }
     return mem_attn->seq_rm(seq_id, p0, p1);
 }
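
Note: the alternative implementation relies on an ordering contract: a failed partial removal on the recurrent cache must leave that cache untouched, so the hybrid cache can simply propagate the failure to the caller. A minimal sketch of that control flow, using stand-in types rather than the real llama.cpp classes:

    // ordering contract behind the new hybrid seq_rm (stand-in types, not llama.cpp)
    struct cache_stub {
        bool can_remove = true;  // whether a partial removal is possible

        bool seq_rm(int p0, int p1) {
            if (!can_remove) {
                return false;    // must NOT mutate any state on failure
            }
            (void) p0; (void) p1; // removal of [p0, p1) elided in this sketch
            return true;
        }
    };

    // try the fallible cache first; touch the second cache only on success
    bool hybrid_seq_rm(cache_stub & recr, cache_stub & attn, int p0, int p1) {
        if (!recr.seq_rm(p0, p1)) {
            return false;        // neither cache has been mutated
        }
        return attn.seq_rm(p0, p1);
    }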

src/llama-memory-recurrent.cpp

Lines changed: 2 additions & 13 deletions
@@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
             const auto & cell = cells[tail_id];
             // partial intersection is invalid if it includes the final pos
             if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
-                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
+                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
                 return false;
             }
             // invalidate tails which will be cleared
@@ -599,21 +599,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const uint32_t i = s*n_seq_tokens;
+        const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
         const int32_t cell_id = s + min;
         auto & cell = cells[cell_id];
 
-        // The temporal plane may have the same value for all image tokens, so we need the max across ALL planes to get the true sequence position.
-        llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
-
-        // For M-RoPE image/audio embeddings, positions are stored in multiple planes. The temporal plane may have the same value for all tokens, so scan all planes for the true max.
-        if (ubatch.n_pos > 1 && ubatch.embd != nullptr) {
-            for (uint32_t p = 0; p < ubatch.n_pos; ++p) {
-                for (uint32_t t = 0; t < n_seq_tokens; ++t) {
-                    last_pos = std::max(last_pos, ubatch.pos[p * ubatch.n_tokens + i + t]);
-                }
-            }
-        }
-
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
             // What should happen when the pos backtracks or skips a value?
             // Clearing the state mid-batch would require special-casing which isn't done.
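
Note: with the multi-plane scan removed, `find_slot` reads `last_pos` straight from the final token of the primary position plane and keeps the existing contiguity check. A self-contained illustration of that check, with made-up values rather than anything from the commit:

    // contiguity check performed per sequence in find_slot (illustrative values)
    #include <cstdio>

    int main() {
        const int cell_pos     = 7;   // last position already stored in the cell
        const int n_seq_tokens = 4;   // tokens this ubatch adds to the sequence
        const int last_pos     = 11;  // position of the ubatch's final token

        // contiguous iff the batch extends the cell by exactly n_seq_tokens;
        // otherwise the pos backtracked or skipped a value
        if (cell_pos >= 0 && last_pos != cell_pos + n_seq_tokens) {
            printf("non-contiguous: pos backtracked or skipped\n");
        } else {
            printf("contiguous: %d + %d == %d\n", cell_pos, n_seq_tokens, last_pos);
        }
        return 0;
    }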

tools/server/server-common.cpp

Lines changed: 43 additions & 17 deletions
@@ -231,19 +231,47 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) :
 server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
 }
 
-llama_pos server_tokens::pos_next() const {
+llama_pos server_tokens::pos_next(int64_t n_tokens) const {
     if (!has_mtmd) {
-        return tokens.size();
+        if (n_tokens < 0) {
+            return tokens.size();
+        }
+
+        return n_tokens;
     }
 
-    llama_pos res = tokens.size();
+    if (n_tokens < 0) {
+        llama_pos res = tokens.size();
 
-    for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
-        const auto & chunk = it->second;
-        res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
+            const auto & chunk = it->second;
+            res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
+        }
+
+        return res;
     }
 
-    return res;
+    int64_t   idx = 0;
+    llama_pos pos = 0;
+
+    GGML_ASSERT(n_tokens <= (int64_t) tokens.size());
+
+    while (idx < n_tokens) {
+        auto media_it = map_idx_to_media.find(idx);
+        if (media_it != map_idx_to_media.end()) {
+            const auto & chunk = media_it->second;
+            const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
+            const size_t    n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
+
+            pos += n_pos;
+            idx += n_tok;
+        } else {
+            pos++;
+            idx++;
+        }
+    }
+
+    return pos;
 }
 
 size_t server_tokens::tokens_up_to_pos(llama_pos max_pos) const {
@@ -252,27 +280,25 @@ size_t server_tokens::tokens_up_to_pos(llama_pos max_pos) const {
     }
 
     size_t idx = 0;
-    llama_pos current_pos = 0;
+    llama_pos pos = 0;
 
     while (idx < tokens.size()) {
        auto media_it = map_idx_to_media.find(idx);
        if (media_it != map_idx_to_media.end()) {
            const auto & chunk = media_it->second;
            const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
-           const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
+           const size_t    n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
 
-           if (current_pos + n_pos > max_pos + 1) {
-               break;
-           }
-           current_pos += n_pos;
+           pos += n_pos;
            idx += n_tok;
        } else {
-           if (current_pos > max_pos) {
-               break;
-           }
-           current_pos++;
+           pos++;
            idx++;
        }
+
+       if (pos > max_pos) {
+           break;
+       }
    }
 
    return idx;
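
Note: the token-count overload exists because, with mtmd media in the prompt, token index and sequence position diverge: a media chunk can occupy a different number of positions than tokens (with M-RoPE, an image's tokens can share temporal positions). A runnable sketch of the same walk over a hypothetical chunk layout:

    // token index vs. position with a media chunk (hypothetical sizes)
    #include <cstdio>
    #include <map>
    #include <utility>

    int main() {
        // token index where a media chunk starts -> { n_tokens, n_pos },
        // e.g. an image occupying 64 tokens but only 8 positions
        std::map<long, std::pair<long, long>> media = { { 3, { 64, 8 } } };

        const long n_tokens = 70;  // 3 text + 64 image + 3 text tokens
        long idx = 0;
        long pos = 0;

        while (idx < n_tokens) {
            auto it = media.find(idx);
            if (it != media.end()) {
                idx += it->second.first;   // advance by the chunk's token count
                pos += it->second.second;  // ... but only by its position count
            } else {
                idx++;                     // a text token takes one position
                pos++;
            }
        }

        printf("pos_next(%ld) = %ld\n", n_tokens, pos);  // 3 + 8 + 3 = 14
        return 0;
    }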

tools/server/server-common.h

Lines changed: 6 additions & 3 deletions
@@ -167,11 +167,14 @@ struct server_tokens {
     // for debugging
     std::string str() const;
 
-    llama_pos pos_next() const;
-    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
+    // the next position after n_tokens; if n_tokens < 0, return the next position after all tokens
+    llama_pos pos_next(int64_t n_tokens = -1) const;
 
+    // number of tokens with position <= max_pos
     size_t tokens_up_to_pos(llama_pos max_pos) const;
-
+
+    const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
+
     void push_back(llama_token tok);
 
     // will create a copy of the chunk if it contains non-text data

tools/server/server-context.cpp

Lines changed: 25 additions & 63 deletions
@@ -77,6 +77,7 @@ struct server_slot {
     size_t last_nl_pos = 0;
 
     std::string generated_text;
+    std::string debug_generated_text;
     llama_tokens generated_tokens;
 
     // idx of draft tokens in the main batch
@@ -425,7 +426,7 @@
 
        if (!only_metrics) {
            res["prompt"]    = ptask->tokens.detokenize(ctx, true);
-           res["generated"] = generated_text;
+           res["generated"] = generated_text.empty() ? debug_generated_text : generated_text;
        }
    }
 
@@ -1441,7 +1442,13 @@
        res->id      = slot.task->id;
        res->id_slot = slot.id;
 
-       res->index = slot.task->index;
+       res->index = slot.task->index;
+
+       // keep copy of last generated text for debugging purposes
+       if (slots_debug) {
+           slot.debug_generated_text = slot.generated_text;
+       }
+
        // in stream mode, content and tokens are already in last partial chunk
        if (slot.task->params.stream) {
            res->content = "";
@@ -2275,14 +2282,14 @@
                    n_past = 0;
                }
 
+               llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
+
                // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
                const auto n_swa = std::max(1, llama_model_n_swa(model));
 
                // the largest pos_min required for a checkpoint to be useful
-               const auto pos_min_thold = std::max(0, n_past - n_swa);
+               const auto pos_min_thold = std::max(0, pos_next - n_swa);
 
-               // note: disallow with mtmd contexts for now
-               // https://github.com/ggml-org/llama.cpp/issues/17043
                if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
                    const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                    if (pos_min == -1) {
@@ -2334,9 +2341,6 @@
                    }
 
                    if (pos_min > pos_min_thold) {
-                       // Removed assert. This is a partial fix
-
-
                        SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
 
                        // search for a context checkpoint
@@ -2361,14 +2365,16 @@
                                do_reset = true;
                                //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
                            } else {
-                               n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
+                               pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
+                               n_past   = slot.prompt.tokens.tokens_up_to_pos(pos_next);
                                SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
                            }
                        }
 
                        if (do_reset) {
                            SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
                                    "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                           pos_next = 0;
                            n_past = 0;
                        }
                    }
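
Note: the restore path now clamps in position space and maps the result back to a token count with `tokens_up_to_pos`, so `n_past` always lands on a chunk boundary (chunks are never split). A free-standing sketch of that round trip, using a hand-rolled `tokens_up_to_pos` stub for the hypothetical prompt from the earlier example (3 text tokens, a 64-token/8-position media chunk, then more text):

    // clamp in position space, then convert back to a token count (sketch)
    #include <algorithm>
    #include <cstdio>

    // stub consistent with the new loop: text before the chunk at positions
    // 0..2, the chunk at positions 3..10, text from position 11 onward
    static long tokens_up_to_pos(long max_pos) {
        if (max_pos < 3)  return max_pos + 1;  // inside the leading text
        if (max_pos < 11) return 3 + 64;       // whole chunk, never split
        return 67 + (max_pos - 10);            // text after the chunk
    }

    int main() {
        long pos_next = 14;                    // e.g. pos_next(n_past) from above
        const long ckpt_pos_min = 5;           // hypothetical checkpoint bounds
        const long ckpt_pos_max = 9;

        // clamp the resume position to what the checkpoint covers ...
        pos_next = std::min(pos_next, std::max(ckpt_pos_min + 1, ckpt_pos_max));
        // ... then convert back to a token count on a chunk boundary
        const long n_past = tokens_up_to_pos(pos_next);

        printf("pos_next = %ld -> n_past = %ld\n", pos_next, n_past);  // 9 -> 67
        return 0;
    }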
@@ -2395,17 +2401,10 @@
                    SLT_WRN(slot, "n_past was set to %d\n", n_past);
                }
 
-
+               slot.n_prompt_tokens_cache = n_past;
                slot.n_prompt_tokens_processed = 0;
 
-               if (slot.prompt.tokens.has_mtmd) {
-                   const int n_tokens_keep = (int)slot.prompt.tokens.tokens_up_to_pos(n_past);
-                   slot.n_prompt_tokens_cache = n_tokens_keep;
-                   slot.prompt.tokens.keep_first(n_tokens_keep);
-               } else {
-                   slot.n_prompt_tokens_cache = n_past;
-                   slot.prompt.tokens.keep_first(n_past);
-               }
+               slot.prompt.tokens.keep_first(n_past);
 
                // send initial 0% progress update if needed
                // this is to signal the client that the request has started processing
@@ -2427,53 +2426,14 @@
                SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
 
                if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
-                   // hybrid model: recurrent partial removal failed.
-                   // find a checkpoint to restore recurrent state from,
-                   // then truncate attention KV to checkpoint position (preserving image KV).
-                   bool recovered = false;
-
-                   if (!slot.prompt.checkpoints.empty()) {
-                       for (auto it = slot.prompt.checkpoints.rbegin(); it != slot.prompt.checkpoints.rend(); ++it) {
-                           if (std::max(it->pos_min, it->pos_max) >= p0) {
-                               continue; // checkpoint is past truncation point
-                           }
-
-                           // truncate attention KV to checkpoint position (and clear recurrent).
-                           // this call will "fail" (return false) because recurrent can't do
-                           // partial removal, but the hybrid seq_rm internally handles it:
-                           // - clears recurrent fully
-                           // - truncates attention from checkpoint pos_max onward
-                           const llama_pos checkpoint_pos = std::max(it->pos_min, it->pos_max);
-                           llama_memory_seq_rm(llama_get_memory(ctx), slot.id, checkpoint_pos, -1);
+                   SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
 
-                           const size_t checkpoint_size = it->data.size();
-                           const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                   slot.prompt_clear(true);
 
-                           if (n == checkpoint_size) {
-                               const int n_past_new = (int)slot.prompt.tokens.tokens_up_to_pos(checkpoint_pos);
-
-                               SLT_WRN(slot, "recovered recurrent state from checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d), n_past: %d -> %d\n",
-                                       it->pos_min, it->pos_max, it->n_tokens_cached, slot.prompt.n_tokens(), n_past_new);
-
-                               slot.prompt.tokens.keep_first(n_past_new);
-                               slot.n_prompt_tokens_cache = n_past_new;
-                               recovered = true;
-                               break;
-                           }
-                       }
-                   }
-
-                   if (!recovered) {
-                       SLT_WRN(slot, "failed to recover recurrent state - clearing the memory%s\n", "");
-
-                       llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-
-                       auto saved_checkpoints = std::move(slot.prompt.checkpoints);
-                       slot.prompt_clear(true);
-                       slot.n_prompt_tokens_cache = 0;
-                       slot.prompt.checkpoints = std::move(saved_checkpoints);
-                   }
+                   // there is no common part left
+                   slot.n_prompt_tokens_cache = 0;
                }
+
                // check if we should process the image
                if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
                    // process the image
@@ -2604,7 +2564,6 @@
                auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
                    /*.pos_min = */ pos_min,
                    /*.pos_max = */ pos_max,
-                   /*.n_tokens_cached = */ slot.prompt.n_tokens(),
                    /*.data = */ std::vector<uint8_t>(checkpoint_size),
                });
 
@@ -2951,6 +2910,9 @@ server_context_meta server_context::get_meta() const {
        /* fim_pre_token */ llama_vocab_fim_pre(impl->vocab),
        /* fim_sub_token */ llama_vocab_fim_suf(impl->vocab),
        /* fim_mid_token */ llama_vocab_fim_mid(impl->vocab),
+       /* fim_pad_token */ llama_vocab_fim_pad(impl->vocab),
+       /* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
+       /* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
 
        /* model_vocab_type */ llama_vocab_type(impl->vocab),
        /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),

tools/server/server-task.h

Lines changed: 0 additions & 1 deletion
@@ -556,7 +556,6 @@ struct server_task_result_apply_lora : server_task_result {
 struct server_prompt_checkpoint {
     llama_pos pos_min;
     llama_pos pos_max;
-    int n_tokens_cached;
 
     std::vector<uint8_t> data;
 
