Better defaults for assisted generation#40976
Conversation
| self.logits_processor = [ | ||
| processor for processor in self.logits_processor if not isinstance(processor, MinLengthLogitsProcessor) | ||
| ] |
There was a problem hiding this comment.
Length is controlled by main model's generation loop, so we should just discard those on the assistant right? @gante
There was a problem hiding this comment.
yes 👍
(see comment on L175-176)
There was a problem hiding this comment.
So we essentially remove the thrown error? Not sure if this is really relevant to this PR, more of a shortener no?
| # Prefer a slightly higher temperature for the assistant when not explicitly provided | ||
| idx = next((i for i, p in enumerate(logits_processor) if isinstance(p, TemperatureLogitsWarper)), None) | ||
| temp_processor = logits_processor.pop(idx) if idx is not None else TemperatureLogitsWarper(temperature=1.0) | ||
|
|
||
| if assistant_temperature is None and temp_processor is not None and temp_processor.temperature < 1.5: | ||
| logger.warning_once( | ||
| f"The assistant's sampling temperature comes from main generation loop set to {temp_processor.temperature}," | ||
| "but speculative decoding benefits from slightly hotter candidate generation, (see #40976)so we are setting it" | ||
| "to 1.5. This should improve decoding speed in most cases. Use `assistant_temperature` to override this value." | ||
| ) | ||
| assistant_temperature = 1.5 | ||
|
|
||
| if assistant_temperature is not None: | ||
| logits_processor.insert(0, TemperatureLogitsWarper(temperature=assistant_temperature)) |
There was a problem hiding this comment.
1.5 seems the most balanced for now, subject to change if experiments show otherwise, I am still benchmarking more models.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| "assistant_model": assistant_model, | ||
| "streamer": streamer, | ||
| "assistant_temperature": kwargs.pop("assistant_temperature", None), | ||
| } | ||
| generation_mode_kwargs["synced_gpus"] = ( |
There was a problem hiding this comment.
no need to change generate signature!! it gets automatically forwarded.
In fact, we could remove assistant_model from the signature (👀 v5?) and all the decoding method-specific kwargs get automatically forwarded.
There was a problem hiding this comment.
no signature change, but it's still an argument (that should be documented)
In any case, I'd rather have it being controlled by assistant_model.generation_config.temperature, in AssistedCandidateGenerator.__init__ -- if it's the default value (1.0), or == main model temperature, then override.
There was a problem hiding this comment.
+1, it should lie within the assistant's generation config if possible. That would be cleaner
| self.logits_processor = [ | ||
| processor for processor in self.logits_processor if not isinstance(processor, MinLengthLogitsProcessor) | ||
| ] |
There was a problem hiding this comment.
yes 👍
(see comment on L175-176)
| # Prefer a slightly higher temperature for the assistant when not explicitly provided | ||
| idx = next((i for i, p in enumerate(logits_processor) if isinstance(p, TemperatureLogitsWarper)), None) | ||
| temp_processor = logits_processor.pop(idx) if idx is not None else TemperatureLogitsWarper(temperature=1.0) | ||
|
|
||
| if assistant_temperature is None and temp_processor is not None and temp_processor.temperature < 1.5: | ||
| logger.warning_once( | ||
| f"The assistant's sampling temperature comes from main generation loop set to {temp_processor.temperature}, " | ||
| "but speculative decoding benefits from slightly hotter candidate generation, (see #40976) so we are setting it " | ||
| "to 1.5. This should improve decoding speed in most cases. Use `assistant_temperature` to override this value." | ||
| ) | ||
| assistant_temperature = 1.5 | ||
|
|
||
| if assistant_temperature is not None: | ||
| logits_processor.insert(0, TemperatureLogitsWarper(temperature=assistant_temperature)) |
There was a problem hiding this comment.
doesn't this change the temperature for both models? 👀 (logits_processor also used in step 2.3)
There was a problem hiding this comment.
Yea that's a good question, we only up the base temperature no? We could also just modify the temperature in place if that's the case
| "assistant_model": assistant_model, | ||
| "streamer": streamer, | ||
| "assistant_temperature": kwargs.pop("assistant_temperature", None), | ||
| } | ||
| generation_mode_kwargs["synced_gpus"] = ( |
There was a problem hiding this comment.
no signature change, but it's still an argument (that should be documented)
In any case, I'd rather have it being controlled by assistant_model.generation_config.temperature, in AssistedCandidateGenerator.__init__ -- if it's the default value (1.0), or == main model temperature, then override.
vasqu
left a comment
There was a problem hiding this comment.
I'm a bit confused whether the assistant and main model really use different logits processors as they use the prepared_logits_processor and whether we could just modify in place then if that's the case. I.e. the main model in general benefits from a higher temperature.
Generally aligned with getting better defaults tho. Just a bit confused if what happens is really what happens per the current comments.
| self.logits_processor = [ | ||
| processor for processor in self.logits_processor if not isinstance(processor, MinLengthLogitsProcessor) | ||
| ] |
There was a problem hiding this comment.
So we essentially remove the thrown error? Not sure if this is really relevant to this PR, more of a shortener no?
| "assistant_model": assistant_model, | ||
| "streamer": streamer, | ||
| "assistant_temperature": kwargs.pop("assistant_temperature", None), | ||
| } | ||
| generation_mode_kwargs["synced_gpus"] = ( |
There was a problem hiding this comment.
+1, it should lie within the assistant's generation config if possible. That would be cleaner
| # Prefer a slightly higher temperature for the assistant when not explicitly provided | ||
| idx = next((i for i, p in enumerate(logits_processor) if isinstance(p, TemperatureLogitsWarper)), None) | ||
| temp_processor = logits_processor.pop(idx) if idx is not None else TemperatureLogitsWarper(temperature=1.0) | ||
|
|
||
| if assistant_temperature is None and temp_processor is not None and temp_processor.temperature < 1.5: | ||
| logger.warning_once( | ||
| f"The assistant's sampling temperature comes from main generation loop set to {temp_processor.temperature}, " | ||
| "but speculative decoding benefits from slightly hotter candidate generation, (see #40976) so we are setting it " | ||
| "to 1.5. This should improve decoding speed in most cases. Use `assistant_temperature` to override this value." | ||
| ) | ||
| assistant_temperature = 1.5 | ||
|
|
||
| if assistant_temperature is not None: | ||
| logits_processor.insert(0, TemperatureLogitsWarper(temperature=assistant_temperature)) |
There was a problem hiding this comment.
Yea that's a good question, we only up the base temperature no? We could also just modify the temperature in place if that's the case
#40657 inadvertently changed an implicit algorithmic bias: candidate_generator (the assistant model) was getting logits_processor while the decoding method (main model) was getting prepared_logits_processor. This meant that the assistant was running with T=1 while the main model was using lower temp.
We investigated and its good for speculation to have a hotter assistant model (so it was a good bug that we were not applying the lower temp to the assistant),
But it should be explicitly set and not a hidden argument forwarding consequence. This PR does that, setting it by default to 1.5.
This PR also fixes:
which originates from the same change in LogitsProcessor passing.