[bnb] We should be able to run 8-bit models on CPU & GPU #20281
younesbelkada wants to merge 2 commits into
Conversation
|
sgugger
left a comment
Mmm, this is a bit misleading, as it will result in weights offloaded to the CPU not being converted to int8 at all.
|
Yes, this is true, maybe I can add a warning telling the user about the underlying behaviour? (weights offloaded on CPU will remain in their native precision) |
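A minimal sketch of what such a warning could look like, assuming the device map is a plain dict from module name to device and that `"cpu"` and `"disk"` mark offloaded modules. The helper name `warn_if_offloaded` and the module names are hypothetical and are not part of this PR.

```python
import warnings


def warn_if_offloaded(device_map):
    """Warn when a device map sends modules to CPU or disk.

    Hypothetical helper: per the discussion above, such modules would stay
    in their native precision instead of being converted to int8.
    """
    offloaded = [name for name, device in device_map.items() if device in ("cpu", "disk")]
    if offloaded:
        warnings.warn(
            "These modules are offloaded and will remain in their native "
            f"precision (not int8): {offloaded}"
        )
    return offloaded


# Illustrative device map; the module names are made up.
example_map = {"encoder": 0, "decoder": "cpu", "lm_head": "disk"}
skipped = warn_if_offloaded(example_map)
```

Returning the list as well as warning keeps the check easy to reuse, for example to decide whether to raise an error instead.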
|
Or we could just leave the error? |
|
Closing this PR as it will bring confusion to users, we should probably wait until |
|
Currently we can pass However, there's a problem here: transformers/src/transformers/utils/bitsandbytes.py Lines 113 to 117 in 61d3928 The Now the thing is - this PR actually almost solves this problem. It modifies So, what I propose: convert this PR into a different one that will allow specifying a full path in |
|
So, if anyone is still interested, this is how I implemented the above solution in my project (via monkey-patching):

import bitsandbytes as bnb
import torch.nn as nn
from accelerate import init_empty_weights  # creates the new layers without allocating weights

# This is a modified version of replace_8bit_linear in transformers/utils/bitsandbytes.py
# The following changes were made:
# 1. modules_to_not_convert can contain full module paths instead of just immediate names
# 2. the default value for modules_to_not_convert is effectively a list instead of a string
# 3. "model" is renamed to "parent_module" to not confuse it with the actual model
# 4. removed redundant check for len(modules)
def replace_8bit_linear(parent_module, threshold=6.0, modules_to_not_convert=None, parent_layer_path=""):
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    parent_layer_prefix = "" if parent_layer_path == "" else parent_layer_path + "."
    for name, module in parent_module.named_children():
        layer_path = parent_layer_prefix + name
        # an excluded full path skips the module together with all of its children
        if layer_path in modules_to_not_convert:
            continue
        replace_8bit_linear(module, threshold, modules_to_not_convert, layer_path)
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            with init_empty_weights():
                parent_module._modules[name] = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    has_fp16_weights=False,
                    threshold=threshold,
                )
    return parent_module

I implemented it a little bit differently than @younesbelkada did, though, and also applied some other small modifications. Then here's the code that gets the layers that need to be ignored:

from typing import Optional

# DeviceMap and DEVICE_CPU are defined elsewhere (not shown)
def get_modules_to_skip_for_int8(device_map: DeviceMap) -> Optional[list[str]]:
    layer_paths = [path for path, device in device_map.items() if device == DEVICE_CPU]
    # adding lm_head based on comment from get_keys_to_not_convert in transformers/utils/bitsandbytes.py
    # which says "for CausalLM modules we may want to keep the lm_head in full precision"
    return layer_paths + ["lm_head"]

In my case I only offload to CPU, not disk. These layers will then be passed as I tested everything and it works well. The only thing I'm not sure about is if the new |
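To make the path-matching rule in the two functions above easier to follow, here is a dependency-free sketch that mimics it on plain strings. It assumes the device map is a dict whose CPU entries use the string `"cpu"`. The names `modules_to_skip` and `is_skipped` and the layer paths are illustrative only and do not come from the comment or from transformers.

```python
from typing import Dict, List, Union

Device = Union[int, str]


def modules_to_skip(device_map: Dict[str, Device]) -> List[str]:
    """Collect the paths placed on CPU, plus lm_head (assumed CPU marker: "cpu")."""
    cpu_paths = [path for path, device in device_map.items() if device == "cpu"]
    return cpu_paths + ["lm_head"]


def is_skipped(layer_path: str, skip: List[str]) -> bool:
    """Mimic the recursive walk: a layer is left unconverted when its own path
    or any ancestor path is listed, or when its immediate name is listed."""
    parts = layer_path.split(".")
    prefixes = [".".join(parts[: i + 1]) for i in range(len(parts))]
    return any(prefix in skip for prefix in prefixes) or parts[-1] in skip


# Made-up device map: one decoder block is offloaded to CPU.
device_map = {"encoder": 0, "decoder.block.0": 0, "decoder.block.1": "cpu"}
skip = modules_to_skip(device_map)

print(is_skipped("decoder.block.1.layer.0.q", skip))   # child of an offloaded block
print(is_skipped("decoder.block.0.layer.0.q", skip))   # stays on GPU, gets converted
print(is_skipped("decoder.block.10.layer.0.q", skip))  # "block.10" is not "block.1"
```

Matching on dot-separated prefixes rather than raw string prefixes is what keeps `decoder.block.10` from being confused with `decoder.block.1`.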
|
@younesbelkada good job!!! I used your PR + @z80maniac's tips and code samples and I managed to load a big model and run 8-bit inference using some GPUs and a large amount of CPU RAM. I think your PR must be merged or at least maintained in a separate branch of HF transformers because I don't believe Everything works great although the inference was kinda slow which is expected when using both GPUs and CPU RAM? I have 2 ideas on how to speed things up a little:
|
|
I tried @z80maniac 's suggestions and while I didn't run into any runtime errors, weights that were supposed to be in fp32 ended up in fp16 (Flan T5 automatically keeps |
How do you know that they're in fp16 and not in fp32? What dtype did you specify if any? |
|
I manually checked the dtypes with
It stays in fp16 whether I pass in Also unfortunately I don't think Flan T5 XXL can handle fp16 or 8-bit precision (see #20683). I'll follow up a bit later because I noticed there was a bug in the code I was testing, though I don't think it should have affected whether or not the weights stayed in fp32 (in fact the bug, if tripped, would have just caused an OOM error instead). |
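The exact snippet used for the manual check is not preserved above. One way to do this kind of check, sketched under the assumption that you have an iterable of `(name, parameter)` pairs whose items expose a `.dtype` attribute (with a PyTorch model, `model.named_parameters()` provides this), is to count parameters per dtype. The helper name and the stand-in parameter names are hypothetical.

```python
from collections import Counter
from types import SimpleNamespace


def dtype_summary(named_params):
    """Count parameters per dtype, given (name, param) pairs with a .dtype attribute."""
    return Counter(str(param.dtype) for _, param in named_params)


# Stand-ins so the sketch runs without torch installed; with a real model,
# pass model.named_parameters() instead. The names below are made up.
fake_params = [
    ("block.0.dense_a.weight", SimpleNamespace(dtype="torch.float16")),
    ("block.0.dense_b.weight", SimpleNamespace(dtype="torch.float32")),
    ("head.weight", SimpleNamespace(dtype="torch.float16")),
]
summary = dtype_summary(fake_params)
print(summary)
```

A summary like this makes it quick to spot whether any weights that were expected in fp32 ended up in fp16.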
|
Ah I wasn't aware of that issue with Flan T5 though I am sure that I have loaded it with torch_dtype=torch.float16 in the past and have not noticed any serious performance degradation though the difference maybe was too subtle to notice.... |
|
Yeah iirc the issue is only with the XXL variant, I think the other variants should run with 16-bit/8-bit quantization just fine. My plan is very close to what you mentioned: I was going to offload I should point out though that 1) even when I did get them offloaded in fp32, I got |
|
from my experience some of those errors can be turned into just a warning (by monkey-patching the python source code) and everything will still work properly |
|
BTW the Flan T5 XXL HF page also has examples of using fp16 and int8. It's possible that the Google team hasn't really tested its performance using quantization though... |
|
I just tried
|
Oh that's great! Personally haven't tried fp16 myself; I can attest to poor results on int8, but I was just going off of other issues/discussions regarding the performance in fp16 (e.g. #20287 (comment)) (EDIT: this issue is from before the relevant patches were merged/when all the weights were in fp16). Just curious though, could you check what dtype the fp16 Flan T5 XXL has its On my end I'm still going to take a look into my issues with offloading into fp32 for completeness' sake. |
|
I now loaded T5 XXL using int8. And again got a satisfying response |
|
Yup, this is actually expected. There's no problem loading the other weights in 8-bit, as long as the My personal problem is that |
|
After running a few simple prompts, I can't see any difference in int8 output when compared to fp32 or fp16. If you have a more sophisticated prompt you wish to try, lmk |
Are you sure you're using the latest versions of |
|
I've been installing I'm working with around 12.7 GB CPU RAM and 15 GB VRAM (standard Colab GPU runtime, Tesla T4). I'll play around with the kwargs to |
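For context on the memory figures mentioned above, here is a rough back-of-the-envelope estimate of weight storage at different precisions. It assumes roughly 11 billion parameters for an XXL-sized T5 (an assumption, not a number from this thread) and counts weights only, ignoring activations and framework overhead.

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def weight_gib(num_params, precision):
    """Approximate weight storage in GiB for a given parameter count and precision."""
    return num_params * BYTES_PER_PARAM[precision] / 2**30


NUM_PARAMS = 11_000_000_000  # assumption: ~11B parameters

for precision in ("fp32", "fp16", "int8"):
    print(f"{precision}: {weight_gib(NUM_PARAMS, precision):.1f} GiB")
```

Under that assumption, even fully int8 weights take up most of a 15 GB GPU, which is consistent with needing to split the model between GPU and CPU.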
|
Ah sorry after rereading my comments I think I've been unclear with what I meant by performance degradation in int8. When we say that Flan T5 XXL can't handle 8-bit, we mean that we can't quantize every single parameter to 8-bit precision the way we traditionally would with a standard T5 (a little misleading to say that since I think My personal problem is that my VRAM isn't large enough to host the 8-bit quantized non- Two further issues arise from this:
On a side note, while writing this I tried loading the model with a device map that put the |
|
got it. thanks. |
|
Thanks for this! I'll take a look into it. I appreciate your help with all of this. |
|
Actually I was trying to use mT0 XXL for a different research project a while back, but I had difficulties just trying to load it into memory. But thanks for prompting me to take another look; I was reviewing my notebook to try to refresh my memory on what I tried and I'm only now seeing I never set |
|
FYI here is what the model card of |
|
cc @Muennighoff if you have any idea what might be wrong here 🙏 |
|
@younesbelkada BTW did you read my comment above with some questions / suggestions? What do you think? |
|
@alexconstant9108 Do you also get the same pad output for mT0 in FP32? |
That's very odd - Do you get the same with mt0-small? Here's what I get:

!pip install -q transformers accelerate
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
checkpoint = "bigscience/mt0-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto", offload_folder="./")
inputs = tokenizer.encode("Translate to English: Je t’aime.", return_tensors="pt").to("cuda")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
<pad> I love you.</s> |
|
@alexconstant9108 Just wanted to let you know I'm still having trouble loading mT0 XXL into memory. Maybe the shard sizes are too big? Not sure; sorry I can't verify your results |
I'm getting the same with Can you double-check your environment? |
|
It looks like your Mine & the uploaded ones are: |
Sorry to bump this thread @alexconstant9108 but would you mind running this prompt?: Splitting the model strictly between GPU and CPU (no disk involved) seemed to fix my problems in terms of getting the |
|
Hi @alexconstant9108 , Thanks for your interest in this feature. I propose to slightly refactor the API in #21579 and enable the feature you have asked for! Feel free to share any thoughts there |
|
@ryan-caesar-ramos after loading all weights as fp32, I also get the same output as you: LLMs generally suck even at basic math (arithmetic included), so the above error is not surprising, especially for a small model like flan-t5-xxl. I think the only remedy for this issue is using a different architecture (not transformer based) or giving the model the ability to call external tools, e.g. a calculator app. I haven't tried yet but I suspect that even ChatGPT will mess up a puzzle like the above |
|
@ryan-caesar-ramos you may want to give FlexGen a try when loading big models: https://github.com/FMInference/FlexGen |
|
Thanks! Will check it out |
What does this PR do?
This PR adds the possibility of using a custom device map containing CPU and GPU devices when loading and running 8-bit models. This is useful in the context of large models, if someone wants to offload part of the model on cpu or on the disk.
Also added slow tests to test this feature; let me know if you think that I am missing any corner case.
cc @sgugger
closes #19090