
Add Support for Dynamically Specifying Layers in enable_lora Method #1821

Open
anirudhr20 opened this issue Sep 11, 2024 · 2 comments

@anirudhr20

Is your feature request related to a problem? Please describe.
The current implementation of the enable_lora method in Keras only allows LoRA to be enabled for a predefined set of layers ("query_dense", "value_dense", "query", "value"). This is limiting because it doesn't provide flexibility for users who want to apply LoRA to other layers.

  • In many custom architectures, users may want to enable LoRA on layers that do not follow the predefined naming convention, and today that requires modifying the library code.
  • Additionally, enabling LoRA on multiple layers can significantly increase the number of trainable parameters in a model, which could lead to memory issues, especially on devices with limited resources (such as GPUs with lower VRAM). Users may want finer control over which layers have LoRA enabled to avoid memory bottlenecks during training.
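To make the memory concern concrete, here is a rough sketch assuming the standard LoRA formulation: a dense kernel of shape (d_in, d_out) gets two trainable factors, A of shape (d_in, rank) and B of shape (rank, d_out), adding rank * (d_in + d_out) parameters. The 12-block, 768-wide model and rank 8 are hypothetical numbers chosen for illustration, and every targeted layer is treated as a plain square dense kernel.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Extra trainable parameters LoRA adds to one (d_in, d_out) dense kernel.

    A has shape (d_in, rank) and B has shape (rank, d_out).
    """
    return rank * (d_in + d_out)


def total_lora_params(kernel_shapes, rank):
    """Sum the LoRA overhead over every targeted kernel."""
    return sum(lora_param_count(d_in, d_out, rank) for d_in, d_out in kernel_shapes)


# Hypothetical model: 12 transformer blocks with hidden size 768, rank 8.
hidden, blocks, rank = 768, 12, 8

# Only query and value projections (like the current defaults).
qv = [(hidden, hidden)] * 2 * blocks
# Query, key, value and output projections.
qkvo = [(hidden, hidden)] * 4 * blocks

print(total_lora_params(qv, rank))    # 294912
print(total_lora_params(qkvo, rank))  # 589824
```

Doubling the number of targeted layers doubles the adapter parameters (and their optimizer state), which is why per-layer control matters on low-VRAM devices.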

Describe the solution you'd like
I propose modifying the enable_lora method to accept a list of custom layer names as a parameter. This way, users can dynamically specify which layers they want to apply LoRA to. The solution could involve adding an optional target_names argument to the method that defaults to the current predefined layers but can be overridden by users. For example:

def enable_lora(self, rank, target_names=None):
    # Fall back to the current hard-coded layer names when none are given.
    if target_names is None:
        target_names = ["query_dense", "value_dense", "query", "value"]

This enhancement would allow more flexibility in applying LoRA to various architectures without needing to modify the core method.
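A minimal sketch of how the method body could use target_names. StubLayer and StubBackbone are hypothetical stand-ins that only mimic the pieces of a Keras layer and backbone the proposal touches; they are not Keras classes, and the real implementation would walk the model's actual sublayers. The up-front check for unknown names is an extra suggestion, not existing behavior.

```python
# Hypothetical stand-ins: not Keras APIs, only enough to show the control flow.
DEFAULT_TARGETS = ("query_dense", "value_dense", "query", "value")


class StubLayer:
    """Stand-in for a layer exposing enable_lora(rank)."""

    def __init__(self, name):
        self.name = name
        self.trainable = True
        self.lora_rank = None

    def enable_lora(self, rank):
        self.lora_rank = rank


class StubBackbone:
    """Stand-in for a backbone holding a flat list of sublayers."""

    def __init__(self, layers):
        self.layers = layers

    def enable_lora(self, rank, target_names=None):
        user_specified = target_names is not None
        # Fall back to the current hard-coded names when none are given.
        targets = set(target_names if user_specified else DEFAULT_TARGETS)
        # Validate user-supplied names before touching any layer, so a typo
        # fails loudly instead of silently training nothing.
        available = {lyr.name for lyr in self.layers if hasattr(lyr, "enable_lora")}
        missing = targets - available
        if user_specified and missing:
            raise ValueError(f"No LoRA-capable layers named: {sorted(missing)}")
        enabled = []
        for layer in self.layers:
            # Freeze every layer; only the matching ones are re-enabled.
            layer.trainable = False
            if layer.name in targets and hasattr(layer, "enable_lora"):
                layer.trainable = True
                layer.enable_lora(rank)
                enabled.append(layer.name)
        return enabled


names = ("query", "key", "value", "ffn_inner")

model = StubBackbone([StubLayer(n) for n in names])
print(model.enable_lora(4))  # ['query', 'value']

custom = StubBackbone([StubLayer(n) for n in names])
print(custom.enable_lora(4, target_names=["query", "key", "ffn_inner"]))
# ['query', 'key', 'ffn_inner']
```

Existing callers that pass only rank keep today's behavior, so the change would be backward compatible.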

Describe alternatives you've considered
Alternatives I have considered:
1. Inheritance and override: subclass the relevant Keras model classes and override the enable_lora method with a custom implementation. This avoids modifying the library directly, but it adds complexity and duplicated code for a common use case.
2. Defining a custom LoRA layer and adding it to the model manually: define a custom LoRA layer and wire it into the relevant parts of the model. I have experimented with this approach, and while it works, it is tedious: you have to implement the LoRA logic yourself and then inject it into the model.
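For reference, the "LoRA logic" that alternative 2 requires re-implementing is roughly the following. This is a framework-free sketch on plain Python lists, assuming the common LoRA convention y = x W + (alpha / r) (x A) B with a frozen kernel W and B initialized to zero; LoRADense and matmul are hypothetical names, not Keras layers.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


class LoRADense:
    """Sketch of a dense layer with a LoRA adapter.

    Computes y = x @ W + (alpha / r) * (x @ A) @ B, where W (d_in x d_out) is
    the frozen pretrained kernel and only A (d_in x r) and B (r x d_out)
    would be trained.
    """

    def __init__(self, kernel, lora_a, lora_b, alpha):
        self.kernel = kernel
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.scale = alpha / len(lora_b)  # r is the number of rows of B

    def __call__(self, x):
        base = matmul(x, self.kernel)
        update = matmul(matmul(x, self.lora_a), self.lora_b)
        return [
            [y + self.scale * u for y, u in zip(base_row, update_row)]
            for base_row, update_row in zip(base, update)
        ]


identity = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0], [0.0]]  # d_in = 2, r = 1

# B starts at zero, so the adapted layer initially matches the frozen one.
layer = LoRADense(identity, a, [[0.0, 0.0]], alpha=2.0)
print(layer([[1.0, 2.0]]))  # [[1.0, 2.0]]

# After "training" B, the low-rank update shifts the output.
layer.lora_b = [[1.0, 1.0]]
print(layer([[1.0, 2.0]]))  # [[3.0, 4.0]]
```

Writing this per layer type, and then splicing it into an existing model graph, is the tedious part that a target_names argument would avoid.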

Additional context
This feature would provide better control for users who are training models on resource-constrained environments and make it easier to integrate LoRA into custom models, especially when experimenting.

Note: I would like to take up this issue and implement the proposed changes. If this feature request is accepted, I will raise a pull request (PR) with the necessary modifications.

@VarunS1997
Collaborator

Have you considered calling enable_lora on a per-layer basis? Is there a reason that approach wouldn't work for your case?

@anirudhr20
Author

Yes, I have considered calling enable_lora on a per-layer basis. While that approach works, I don't think it is the most convenient option for some use cases.

For instance, in large models with many layers, manually identifying each target layer and calling enable_lora on it is time-consuming and error-prone. Being able to pass a list of layer names to the enable_lora method would make the code cleaner and easier to maintain.
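As a sketch of how that selection could scale to deep models, target names could also be matched as glob patterns against layer paths. select_lora_targets is a hypothetical helper and the layer paths below are made up for illustration; neither reflects how Keras actually names layers.

```python
import fnmatch


def select_lora_targets(layer_paths, patterns):
    """Return the layer paths matching any glob pattern, in model order.

    Hypothetical helper: pattern matching avoids listing every layer by hand.
    """
    return [p for p in layer_paths if any(fnmatch.fnmatchcase(p, pat) for pat in patterns)]


# Hypothetical layer paths for a 24-block transformer.
paths = []
for i in range(24):
    paths += [
        f"transformer_layer_{i}/self_attention/query",
        f"transformer_layer_{i}/self_attention/key",
        f"transformer_layer_{i}/self_attention/value",
        f"transformer_layer_{i}/feedforward_intermediate_dense",
    ]

# Query and value in every block.
selected = select_lora_targets(paths, ["*/query", "*/value"])
print(len(selected))  # 48
print(selected[0])    # transformer_layer_0/self_attention/query

# Only the attention projections of the last four blocks (20-23), to save memory.
late = select_lora_targets(paths, ["transformer_layer_2[0-3]/self_attention/*"])
print(len(late))  # 12
```

fnmatch's `*` also matches `/`, so `*/query` matches at any depth; a stricter scheme could split paths on `/` instead.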

Additionally, I'd like to highlight that Hugging Face's PEFT library already supports LoRA with user-specified target layers, through the target_modules argument of LoraConfig:

from peft import LoraConfig

target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"]
config = LoraConfig(
    r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)

This dynamic specification provides flexibility and allows users to enable LoRA on multiple layers with minimal effort. Implementing something similar in Keras would ensure consistent usage across libraries, making it easier for users who switch between frameworks or work with different libraries in the same pipeline.

While calling enable_lora per layer works for small-scale use cases, the enhancement I am proposing is aimed at simplifying this process for larger and more complex models, while aligning Keras more closely with other widely-used libraries like Hugging Face.

Please let me know if this approach makes sense, or if there are any additional considerations I should take into account.
