Skip to content

Commit

Permalink
Add use_8bit Adam optimizer config to the UI.
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanJDick committed May 21, 2024
1 parent 654bba1 commit cd85bb2
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/invoke_training/ui/config_groups/optimizer_config_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@
class AdamOptimizerConfigGroup(UIConfigElement):
def __init__(self):
with gr.Tab("Core"):
self.learning_rate = gr.Number(
label="Learning Rate",
info="Initial learning rate to use (after the potential warmup period). Note that in some training "
"pipelines this can be overriden for a specific group of params.",
interactive=True,
)
with gr.Row():
self.learning_rate = gr.Number(
label="Learning Rate",
info="Initial learning rate to use (after the potential warmup period). Note that in some training "
"pipelines this can be overriden for a specific group of params.",
interactive=True,
)
self.use_8bit = gr.Checkbox(
label="Use 8-bit",
info="Use 8-bit Adam optimizer to reduce VRAM requirements. (Requires bitsandbytes.)",
interactive=True,
)
with gr.Tab("Advanced"):
with gr.Row():
self.beta1 = gr.Number(label="beta1", interactive=True)
Expand All @@ -32,6 +38,7 @@ def update_ui_components_with_config_data(self, config: AdamOptimizerConfig) ->
self.beta2: config.beta2,
self.weight_decay: config.weight_decay,
self.epsilon: config.epsilon,
self.use_8bit: config.use_8bit,
}

def update_config_with_ui_component_data(
Expand All @@ -45,6 +52,7 @@ def update_config_with_ui_component_data(
beta2=ui_data.pop(self.beta2),
weight_decay=ui_data.pop(self.weight_decay),
epsilon=ui_data.pop(self.epsilon),
use_8bit=ui_data.pop(self.use_8bit),
)


Expand Down

0 comments on commit cd85bb2

Please sign in to comment.