Skip to content

[bnb] We should be able to run 8-bit models on CPU & GPU#20281

Closed
younesbelkada wants to merge 2 commits into
huggingface:mainfrom
younesbelkada:bnb_add_custom_map
Closed

[bnb] We should be able to run 8-bit models on CPU & GPU#20281
younesbelkada wants to merge 2 commits into
huggingface:mainfrom
younesbelkada:bnb_add_custom_map

Conversation

@younesbelkada

@younesbelkada younesbelkada commented Nov 16, 2022

Copy link
Copy Markdown
Contributor

What does this PR do?

This PR adds the possibility of using a custom device map containing CPU and GPU devices when loading and running 8-bit models. This is useful in the context of large models, if someone wants to offload part of the model on cpu or on the disk.
Added also slow tests to test this feature, let me know if you think that I am missing any corner case.

cc @sgugger

closes #19090

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada younesbelkada changed the title We should run 8-bit models on CPU & GPU [bnb] We should run 8-bit models on CPU & GPU Nov 16, 2022
@younesbelkada younesbelkada changed the title [bnb] We should run 8-bit models on CPU & GPU [bnb] We should be able to run 8-bit models on CPU & GPU Nov 16, 2022

@sgugger sgugger left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm, this is a bit misleading as this will result in weights offloaded on the CPU to not be converted in int8 at all.

@younesbelkada

Copy link
Copy Markdown
Contributor Author

Yes, this is true, maybe I can add a warning telling the user about the underlying behaviour? (weights offloaded on CPU will remain in their native precision)

@sgugger

sgugger commented Nov 16, 2022

Copy link
Copy Markdown
Collaborator

Or we could just leave the error?

@younesbelkada

Copy link
Copy Markdown
Contributor Author

Closing this PR as it will bring confusion to users, we should probably wait until bitsandbytes supports weights offloading in 8-bit to add this feature
Thanks!

@z80maniac

Copy link
Copy Markdown

Currently we can pass load_in_8bit_skip_modules into model_kwargs. This will allow to not convert certain layers/modules/weights into 8-bit.

However, there's a problem here:

for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_8bit_linear(module, threshold, modules_to_not_convert)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:

The name is not the full path to the module, e.g. for transformer.h.0.ln1 it can be 0 or ln1, etc, depending on the recursion level. So currently it's impossible to ignore a specific layer or a group of layers. For example, if transformer.h.0 is on CPU, then I don't want it (and any of its sub-layers) to be converted to 8-bit, but I can't specify this layer in load_in_8bit_skip_modules. Furthermore, even specifying 0 won't help, because child sub-layers (e.g. *.0.ln1) are processed first.

Now the thing is - this PR actually almost solves this problem. It modifies replace_8bit_linear so that it can handle ignoring modules by the full path, not just by the immediate name.

So, what I propose: convert this PR into a different one that will allow specifying a full path in load_in_8bit_skip_modules. This will allow to manually ignore non-GPU layers when needed and it will not confuse the users.

@z80maniac

Copy link
Copy Markdown

So, if anyone is still interested, this is how I implemented the above solution in my project (via monkey-patching):

https://github.com/alkatrazstudio/neodim-server/blob/93e4819d3633841ca4f42246f51e28f355ed6cf5/src/bnb_override.py#L9-L37

# This is a modified version of replace_8bit_linear in transformers/utils/bitsandbytes.py
# The following changes were made:
# 1. modules_to_not_convert can contain full module paths instead of just immediate names
# 2. the default value for modules_to_not_convert is effectively a list instead of a string
# 3. "model" is renamed to "parent_module" to not confuse it with the actual model
# 4. removed redundant check for len(modules)
def replace_8bit_linear(parent_module, threshold=6.0, modules_to_not_convert=None, parent_layer_path=""):
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    parent_layer_prefix = "" if parent_layer_path == "" else parent_layer_path + "."
    for name, module in parent_module.named_children():
        layer_path = parent_layer_prefix + name

        if layer_path in modules_to_not_convert:
            continue

        replace_8bit_linear(module, threshold, modules_to_not_convert, layer_path)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            with bitsandbytes.init_empty_weights():
                parent_module._modules[name] = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    has_fp16_weights=False,
                    threshold=threshold,
                )

    return parent_module

I implemented it a little bit differently than @younesbelkada did, though, and also applied some other small modifications.

Then here's the code that gets the layers that need to be ignored:

https://github.com/alkatrazstudio/neodim-server/blob/93e4819d3633841ca4f42246f51e28f355ed6cf5/src/dev_map.py#L142-L147

def get_modules_to_skip_for_int8(device_map: DeviceMap) -> Optional[list[str]]:
    layer_paths = [path for path, device in device_map.items() if device == DEVICE_CPU]

    # adding lm_head based on comment from get_keys_to_not_convert in transformers/utils/bitsandbytes.py
    # which says "for CausalLM modules we may want to keep the lm_head in full precision"
    return layer_paths + ["lm_head"]

In my case I only offload to CPU, not disk.

These layers will then be passed as load_in_8bit_skip_modules to the from_pretrained method.

I tested everything and it works well. The only thing I'm not sure about is if the new replace_8bit_linear actually backwards-compatible with the old version. It's compatible when modules_to_not_convert=["lm_head"], but I'm not sure about the generic use-case.

@alexconstant9108

Copy link
Copy Markdown

@younesbelkada good job!!! I used your PR + @z80maniac tips and code samples and I managed to load a big model and run 8bit inference using some gpus and big amount of cpu RAM. I think your PR must be merged or at least mainatained in a separate branch into HF transfomers because I don't believe bitsandbytes will ever implement CPU offloading in their project. I read such opinions among the issues list in their project....

Everything works great although the inference was kinda slow which is expected when using both GPUs and CPU RAM?

I have 2 ideas on how to speed things up a little:

  1. It looks like the fp16 weights (offloaded to the CPU RAM) get copied back and forth on every pass into the VRAM of the first? GPU to do the calcultions? If true, then we may as well store those weights into 8bit on the CPU RAM from the start in order to avoid converting from fp16 into int8 and to also decrease the CPU RAM requirements by half?
  2. Another approach would be to perform the calculations on those fp16 weights right using the available CPU cores and thus avoid copying back and forth all of the weights into GPU VRAM on every pass?
    Does any of the above make any sense?

@ryan-caesar-ramos

Copy link
Copy Markdown

I tried @z80maniac 's suggestions and while I didn't run into any runtime errors, weights that were supposed to be in fp32 ended up in fp16 (Flan T5 automatically keeps wo layers in fp32, which didn't happen after applying the monkey patches). Is this expected?

@alexconstant9108

Copy link
Copy Markdown

I tried @z80maniac 's suggestions and while I didn't run into any runtime errors, weights that were supposed to be in fp32 ended up in fp16 (Flan T5 automatically keeps wo layers in fp32, which didn't happen after applying the monkey patches). Is this expected?

How do you know that they're in fp16 and not in fp32? What dtype did you specify if any?
BTW fp16 should be OK too. there will barely be any performance degradation especially given that most of the weights are stored in int8 anyway

@ryan-caesar-ramos

Copy link
Copy Markdown

I manually checked the dtypes with

model.encoder.block[1].layer[1].DenseReluDense.wo.weight.dtype, model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype

It stays in fp16 whether I pass in torch_dtype=torch.float32 or leave the kwarg untouched.

Also unfortunately I don't think Flan T5 XXL can handle fp16 or 8-bit precision (see #20683). T5ForConditionalGeneration has a _keep_in_fp32_modules attribute that's supposed to help the wo layers stay in fp32, but I noticed that the suggested monkey patches might be interfering.

I'll follow up a bit later because I noticed there was a bug in the code I was testing, though I don't think it should have affected whether or not the weights stayed in fp32 (in fact the bug, if tripped, would have just caused an OOM error instead).

@alexconstant9108

Copy link
Copy Markdown

Ah I wasn't aware of that issue with Flan T5 though I am sure that i have loaded it with torch_dtype=torch.float16 in the past and have not noticed any serios performance degradation though the difference maybe was too subtle to notice....
What is your plan when loading it? wo layers in fp32 into CPU RAM and any other weights in int8 into VRAM? What does ur device map look like?
BTW i am only using the code changes by @younesbelkada from this pr
@z80maniac are a bit different though still helpful

@ryan-caesar-ramos

Copy link
Copy Markdown

Yeah iirc the issue is only with with XXL variant, I think the other variants should run with 16-bit/8-bit quantization just fine.

My plan is very close to what you mentioned: I was going to offload wo layers onto CPU RAM/disk (though I've only been experimenting with disk so far) and keep the others in int8 on GPU. I'll share my device map soon but I put myself into a bit of a situation rn. After some code changes I'm currently unable to reproduce the whole "wo-layers-are-offloaded-onto-disk-in-fp32 " thing, so I need to fix that first. I'll follow up once I figure that out.

I should point out though that 1) even when I did get them offloaded in fp32, I got RuntimeError: Expected all tensors to be on the same device, but found at least two devices, meta and cuda:0! anyway and 2) there's a chance the problem is on my end and not with any of the monkey patches, so I really should fix a few things first.

@alexconstant9108

Copy link
Copy Markdown

from my experience some of those errors can be turned into just a warning (by monkey-patching the python source code) and everything will still work properly

@alexconstant9108

alexconstant9108 commented Feb 6, 2023

Copy link
Copy Markdown

BTW the Flan T5 XXL HF page also has examples of using fp16 and int8. It's possible that the Google team hasn't really tested its performance using quantization though...
If you give me some example prompts, I will try to reproduce the results locally?

@alexconstant9108

Copy link
Copy Markdown

I just tried device_map="auto", torch_dtype=torch.float16 on multiple gpus:
translate English to German: How old are you?

<pad> Wie alt sind Sie?</s>

@ryan-caesar-ramos

ryan-caesar-ramos commented Feb 6, 2023

Copy link
Copy Markdown

I just tried device_map="auto", torch_dtype=torch.float16 on multiple gpus: translate English to German: How old are you?

<pad> Wie alt sind Sie?</s>

Oh that's great! Personally haven't tried fp16 myself; I can attest to poor results on int8, but I was just going off of other issues/discussions regarding the performance in fp16 (e.g. #20287 (comment)) (EDIT: this issue is from before the relevant patches were merged/when all the weights were in fp16).

Just curious though, could you check what dtype the fp16 Flan T5 XXL has its wo layers in? If I'm not mistaken, unless you manually disable it (i.e. T5ForConditionalGeneration.__keep_fp32_modules = None) it should set the wo layers to fp32. But if they're really all in fp16 now that's great!

On my end I'll still going to take a look into my issues with offloading into fp32 for completeness' sake.

@alexconstant9108

Copy link
Copy Markdown

I now loaded T5 XXL using int8.
print(model.encoder.block[1].layer[1].DenseReluDense.wo.weight.dtype) torch.float32 print(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype) torch.float32

And again got a satisfying response
<pad> Wie alt sind Sie?</s>

@ryan-caesar-ramos

Copy link
Copy Markdown

Yup, this is actually expected. There's no problem loading the other weights in 8-bit, as long as the wo layers are in fp32 as shown above.

My personal problem is that device_map="auto" is acting strangely for me (perhaps it's calculating on the assumption that the wo layers are in 8-bit when in fact they'll be loaded in 32-bit, which causes the OOM error) so I've been making custom device maps in the meantime. The farthest I've gotten is the runtime error I mentioned previously regarding the different devices (one of them being a meta device), but I've yet to recreate that because I changed my code at some point and need to figure out how to get it back to how it used to be. I thought the monkey patches here might help, but I'm starting to think the breaking changes I made were done before I tried the monkey patches, resulting the persistent fp16 offloaded weights that come up even without the monkey patches.

@alexconstant9108

Copy link
Copy Markdown

After running a few simple prompts, I can't see any difference in int8 output when compared to fp32 or fp16. If you have a more sophisticated prompt you wish to try, lmk

@alexconstant9108

alexconstant9108 commented Feb 6, 2023

Copy link
Copy Markdown

Yup, this is ...

Are you sure you're using the latest versions of transformers, accelerate and bnb? Perhaps, first uninstall everything you currently have, then install the above python packages and reapply the monkey patches on top of the latest versions.
Also, what is your hardware setup like? How much GPU VRAM in total, CPU RAM? You may want to try also setting the max_memory map per GPU device (but that requires some tweaking and is card / model dependent). Also, even if you get it running, keep in mind that offloading to SSD makes things reaaaaly slow. Offloading to just CPU RAM is a bit better

@ryan-caesar-ramos

Copy link
Copy Markdown

I've been installing transformers and accelerate from source, but yeah I haven't tried installing bnb from source, I'll try that.

I'm working with around 12.7 GB CPU RAM and 15 GB VRAM (standard Colab GPU runtime, Tesla T4). I'll play around with the kwargs to infer_auto_device_map but I still have a hunch that infer_auto_device_map has no way of knowing that the wo modules will end up being larger than what it currently is accounting for. And yeah you're right I definitely should offload to CPU, I've just been spending the past few days trying to get the weights to be stored in fp32 first (even before the monkey patches, I previously kept on having the wo weights in fp16. I fixed it earlier but then ended up breaking it again).

@ryan-caesar-ramos

Copy link
Copy Markdown

Ah sorry after rereading my comments I think I've been unclear with what I meant by performance degradation in int8.

When we say that Flan T5 XXL can't handle 8-bit, we mean that we can't quantize every single parameter to 8-bit precision the way we traditionally would with a standard T5 (a little misleading to say that since I think lm_head layers also can't be in 8-bit); doing so leads to poor performance. The solution transformers implemented was to do standard 8-bit quantization everywhere except the wo layers, since those were the only layers that needed to be in 32-bit. If you do that, you get the expected full performance, as you've demonstrated with your examples.

My personal problem is that my VRAM isn't large enough to host the 8-bit quantized non-wo modules alongside the 32-bit wo modules (if the entire model was 8-bit quantized, I could though). Because of that I need to offload some weights. As you probably already know, the main branches of bitsandbytes and transformers are currently a little weird when it comes to using 8-bit quantization alongside offloading. You can offload but the offloaded weights won't be in 8-bit. That's why I decided to just offload the wo layers only, since they shouldn't be in 8-bit anyway.

Two further issues arise from this:

  1. auto_device_map=True doesn't work well because when infer_auto_device_map receives dtype=torch.int8, it calculates device allocation as if everything will be in int8, consequently underestimating how much space a wo module will take up, which ultimately leads to memory errors.
  2. We can avoid the above problem if we define our own device map. The furthest I've gotten with this however is the runtime error about the different devices.

On a side note, while writing this I tried loading the model with a device map that put the wo layers on CPU instead of the disk. Surprisingly it fit, and unsurprisingly the session crashed (out of RAM) when I tried to use the model. On a somewhat more positive note though, before it crashed I noticed that it was back in fp32, which was nice. Still unsure what's causing the jump back-and-forth between fp32 and fp16, but I'm still looking into it.

@alexconstant9108

Copy link
Copy Markdown

got it. thanks.
In this PR there is a piece of code which specifies which modules to skip. Just specify lm_head and the wo layers there (+ any others as needed) or use a custom device_map. Then set max_memory for your single(?) GPU to the GPU VRAM - ~1.3-2.2GB (it will take a few attempts to get this amount to an optimal point).
From what I gather, you should be able to host around 13.4 GB of int8 weights on the Tesla GPU and the rest (in fp32) onto CPU RAM + SSD.
Flan XXL may turn out to be too big for your setup though - meaning that at least 4-5 GB will likely end up on the SSD

@ryan-caesar-ramos

Copy link
Copy Markdown

Thanks for this! I'll take a look into it. I appreciate your help with all of this.

@ryan-caesar-ramos

Copy link
Copy Markdown

Actually I was trying to use mT0 XXL for a different research project a while back, but I had difficulties just trying to load it into memory. But thanks for prompting me to take another look; I was reviewing my notebook to try to refresh my memory on what I tried and I'm only now seeing I never set load_in_8bit to True, so I'll try again soon. Thanks again!

@younesbelkada

younesbelkada commented Feb 7, 2023

Copy link
Copy Markdown
Contributor Author

FYI here is what the model card of mt0-xxl states:

Prompt Engineering: The performance may vary depending on the prompt. For BLOOMZ models, we recommend making it very clear when the input stops to avoid the model trying to continue it. 
For example, the prompt "Translate to English: Je t'aime" without the full stop (.) at the end, may result in the model trying to continue the French sentence.
 Better prompts are e.g. "Translate to English: Je t'aime.", "Translate to English: Je t'aime. Translation:" "What is "Je t'aime." in English?", where it is clear for the model when it should answer. 
Further, we recommend providing the model as much context as possible. 
For example, if you want it to answer in Telugu, then tell the model, e.g. "Explain in a sentence in Telugu what is backpropagation in neural networks.".

@younesbelkada

Copy link
Copy Markdown
Contributor Author

cc @Muennighoff if you have any idea what might be wrong here 🙏

@alexconstant9108

Copy link
Copy Markdown

@younesbelkada BTW did you read my comment above with some questions / suggestions? What do you think?
#20281 (comment)

@Muennighoff

Copy link
Copy Markdown
Contributor

@alexconstant9108 Do you also get the same pad output for mT0 in FP32?

@Muennighoff

Copy link
Copy Markdown
Contributor

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "bigscience/mt0-xxl"

tokenizer = AutoTokenizer.from_pretrained(checkpoint) model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")

inputs = tokenizer.encode("Translate to English: Je t’aime.", return_tensors="pt").to("cuda") outputs = model.generate(inputs)

That's very odd - Do you get the same with mt0-small? Here's what I get:

!pip install -q transformers accelerate
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
checkpoint = "bigscience/mt0-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto", offload_folder="./")
inputs = tokenizer.encode("Translate to English: Je t’aime.", return_tensors="pt").to("cuda")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))

<pad> I love you.</s>

@ryan-caesar-ramos

Copy link
Copy Markdown

@alexconstant9108 Just wanted to let you know I'm still having trouble loading mT0 XXL into memory. Maybe the shard sizes are too big? Not sure; sorry I can't verify your results

@Muennighoff

Copy link
Copy Markdown
Contributor

Yeah, mt0-small seems to work fine: <pad> I love you.</s> I will check the hashes of the downloaded weights of mt0-xxl when I have some time

I'm getting the same with mt0-xxl:

Python 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
>>> checkpoint = "bigscience/mt0-xxl"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto", offload_folder="./")
Downloading (…)model.bin.index.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50.7k/50.7k [00:00<00:00, 916kB/s]
Downloading (…)00001-of-00006.bin";: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.94G/9.94G [00:52<00:00, 191MB/s]
Downloading (…)00002-of-00006.bin";: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.87G/9.87G [04:33<00:00, 36.1MB/s]
Downloading (…)00003-of-00006.bin";: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.87G/9.87G [00:50<00:00, 194MB/s]
Downloading (…)00004-of-00006.bin";: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10.0G/10.0G [00:51<00:00, 195MB/s]
Downloading (…)00005-of-00006.bin";: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10.0G/10.0G [00:52<00:00, 191MB/s]
Downloading (…)00006-of-00006.bin";: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.11G/6.11G [00:30<00:00, 198MB/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:35<00:00,  5.96s/it]
>>> inputs = tokenizer.encode("Translate to English: Je t’aime.", return_tensors="pt").to("cuda")
>>> outputs = model.generate(inputs)
>>> print(tokenizer.decode(outputs[0]))
<pad> I love you.</s>

Can you double-check your environment?
Mine looks like:

accelerate-0.16.0
transformers-4.26.0
tokenizers-0.13.2
pytorch-1.13.1
CUDA Version: 11.6 (A100 80GB)

@Muennighoff

Copy link
Copy Markdown
Contributor

It looks like your pytorch_model-00003-of-00006.bin has a different sha256 than the uploaded one: https://huggingface.co/bigscience/mt0-xxl/blob/main/pytorch_model-00003-of-00006.bin

Mine & the uploaded ones are:

295a276775d79359cfd243bd93c9e2c408a8e33718e5bee1d05625f026af6175
c21533a6182886bec48cd0190952b3c5e71224873234135c2754f7c81d02ac82
62cc874eb7f5cfa6fcbde4a19bab7de1f7bf8b47f0f01c45713927115c85a153
36b2a5945f7c037b99eaf5ed891fc158b23a791b92042861a5298b0c8ec224be
3f769732a1c4ba3a9cbd9ea1c2701ade3cdf2a35f73e75ac77d0c26788a5d88f
92679f99746d0e1082d7407091cb7f2a588d49b9bf13724f706e8912f86c5786

@ryan-caesar-ramos

Copy link
Copy Markdown

After running a few simple prompts, I can't see any difference in int8 output when compared to fp32 or fp16. If you have a more sophisticated prompt you wish to try, lmk

Sorry to bump this thread @alexconstant9108 but would you mind running this prompt?:

Answer the following question by reasoning step-by-step.\nThe cafeteria had 23 apples. If they used 20 for lunch and bought 6 more, how many apples do they have?

Splitting the model strictly between GPU and CPU (no disk involved) seemed to fix my problems in terms of getting the wo layers in fp32. While I was able to get the expected output for translate English to German: How old are you?, my output for the above was unfortunately

<pad> The cafeteria has 23 - 20 = 3 apples left. They have 3 + 6 = 7 apples. Therefore, the answer is 7.</s>

@younesbelkada

Copy link
Copy Markdown
Contributor Author

Hi @alexconstant9108 ,

Thanks for your interest in this feature. I propose to slightly refactor the API in #21579 and enable the feature you have asked for! Feel free to share any thoughts there

@alexconstant9108

Copy link
Copy Markdown

@ryan-caesar-ramos after loading all weights as fp32, I also get the same output as you:
print(tokenizer.decode(outputs[0])) <pad> The cafeteria has 23 - 20 = 3 apples left. They have 3 + 6 = 7 apples. Therefore, the answer is 7.</s>

LLMs generally suck even at basic Math(arithmetic included) so the above error is not surprising especially for a small model like flan-t5-xxl. I think the only remedy for this issue is using a different architecture (not transformer based) or adding the ability to the model to call external tools e.g. a calculator app. I haven't tried yet but I suspect that even ChatGPT will mess up a puzzle like the above

@alexconstant9108

Copy link
Copy Markdown

@ryan-caesar-ramos you may want to give FlexGen a try when loading big models: https://github.com/FMInference/FlexGen

@ryan-caesar-ramos

Copy link
Copy Markdown

Thanks! Will check it out

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Tracker] [bnb] Supporting device_map containing GPU and CPU devices

7 participants