Loading FlaxHybridCLIP trained model #22673

Description

@alhuri

System Info

  • transformers version: 4.27.4
  • Platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.13.4
  • PyTorch version (GPU?): 1.9.0+cpu (False)
  • Tensorflow version (GPU?): 2.9.1 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.6.8 (cpu)
  • Jax version: 0.4.8
  • JaxLib version: 0.4.7
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Models: FlaxHybridCLIP

Who can help?

@sanchit-gandhi @patrickvonplaten, @patil-suraj

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm trying to load a trained FlaxHybridCLIP model from a folder that contains the following files:

  • config.json
  • flax_model.msgpack

I attempted to load it using the code below:

    if args.run_from_checkpoint is not None:
        with open(f"{args.run_from_checkpoint}/config.json", "r") as fp:
            config_dict = json.load(fp)
        config_dict["vision_config"]["model_type"] = "clip"
        config = HybridCLIPConfig(**config_dict)
        model = FlaxHybridCLIP.from_pretrained(
            args.run_from_checkpoint,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            config=config,
            freeze_backbones=args.freeze_backbones
        )

However, I encountered the following error:

`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.
`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.
loading weights file freeze/18/flax_model.msgpack
Traceback (most recent call last):
  File "run_hybrid_clip.py", line 831, in <module>
    main()
  File "run_hybrid_clip.py", line 528, in main
    model = FlaxHybridCLIP.from_pretrained(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/transformers/modeling_flax_utils.py", line 807, in from_pretrained
    model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
  File "/home/ubuntu/modeling_hybrid_clip.py", line 148, in __init__
    module = self.module_class(config=config, dtype=dtype, **kwargs)
TypeError: __init__() got an unexpected keyword argument '_do_init'
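
Reading the traceback, it looks like `from_pretrained` passes `_do_init` to the model constructor, and the custom `__init__` in modeling_hybrid_clip.py forwards all extra keyword arguments to the module class, which does not accept it. Below is a minimal sketch of that mechanism in plain Python. `FakeModule`, `BrokenModel`, `PatchedModel`, and `load` are hypothetical stand-ins and not part of transformers or Flax; whether naming `_do_init` in the real `FlaxHybridCLIP.__init__` is enough to fix the script is an assumption I have not verified.

```python
# Hypothetical stand-ins only: none of these classes come from transformers or Flax.


class FakeModule:
    """Plays the role of the Flax module: accepts only config and dtype."""

    def __init__(self, config, dtype="float32"):
        self.config = config
        self.dtype = dtype


class BrokenModel:
    """Forwards every extra keyword to the module, as the traceback suggests."""

    module_class = FakeModule

    def __init__(self, config, dtype="float32", **kwargs):
        # _do_init ends up inside kwargs and is passed on to FakeModule.
        self.module = self.module_class(config=config, dtype=dtype, **kwargs)


class PatchedModel:
    """Names _do_init explicitly so it never reaches the module."""

    module_class = FakeModule

    def __init__(self, config, dtype="float32", _do_init=True, **kwargs):
        self._do_init = _do_init
        self.module = self.module_class(config=config, dtype=dtype, **kwargs)


def load(cls, config, _do_init=True):
    """Mimics the call shown in the traceback: cls(config, _do_init=_do_init)."""
    return cls(config, _do_init=_do_init)


try:
    load(BrokenModel, config={})
    broken_error = None
except TypeError as exc:
    # Same kind of failure as in the traceback above.
    broken_error = str(exc)

patched = load(PatchedModel, config={})
print(broken_error)
```

In this sketch the only difference between the two classes is that `PatchedModel` lists `_do_init` as a named parameter, which keeps it out of `**kwargs`.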

I used the modified Italian hybrid CLIP scripts here

Expected behavior

The model should load successfully so that I can fine-tune it with an unfrozen backbone.

Thanks
