Improve Mamba Speed #129
Comments
It is expected that Mambular is slower than e.g. the FT-Transformer, especially for datasets with a lot of features, since training time increases linearly with sequence length (the number of features). However, in our experiments the slowdown was only a factor of 2.5-3, and Mambular was more memory efficient than the FT-Transformer. Could you provide a minimal code example with simulated data where you experience similar training times? Then we can verify.
Hello AnFreTh, Thank you for your reply. Based on your suggestion, I have prepared a minimal code example for you to review. In my current framework, I am using Mambular as the tabular encoder within a table-image contrastive learning setting. I defined a
For simplicity, the provided code example only uses simulated tabular data. This dataset has 8139 samples in total, with 6530 samples split between the training and validation sets. Each sample consists of 423 numerical features only, with no categorical features. When running this simplified code (with a batch size of 16), training the Mambular encoder takes approximately 2.5 hours per epoch, while using the FT-Transformer encoder takes around 15 seconds per epoch, and using ResNet as the encoder takes about 7 seconds per epoch. I have attached the code example for your review. Please let me know if anything else is needed to further investigate the issue. Thank you again for your help!
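For reference, here is a rough sketch of the kind of simulated-data timing setup described above. It is not the attached code: the encoder below is a hypothetical placeholder MLP, and the shapes simply follow the description (6530 training samples, 423 numerical features, batch size 16); swap in the Mambular, FT-Transformer, or ResNet encoders from the actual setup to compare per-epoch times.

```python
# Sketch of a per-epoch timing harness on simulated tabular data.
# SimpleEncoder is a placeholder; replace it with the encoder under test.
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

n_samples, n_features, d_model, batch_size = 6530, 423, 64, 16
X = torch.randn(n_samples, n_features)   # simulated numerical features
y = torch.randn(n_samples, 1)            # simulated target
loader = DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=True)

class SimpleEncoder(nn.Module):          # hypothetical stand-in encoder
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, d_model), nn.ReLU(), nn.Linear(d_model, 1)
        )
    def forward(self, x):
        return self.net(x)

model = SimpleEncoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

start = time.time()
for xb, yb in loader:                    # one epoch
    optimizer.zero_grad()
    loss = loss_fn(model(xb), yb)
    loss.backward()
    optimizer.step()
print(f"one epoch took {time.time() - start:.1f}s")
```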
I could not reproduce the extreme differences you reported, but even so, the default Mambular was about 10x slower than the FT-Transformer for this specific setup. We will update the current Mamba block implementation to increase speed.
Thank you for taking the time to investigate the issue. I will try these suggestions and look forward to your updates. Thanks again for your help and support!
If you experiment further, you could, instead of the Python Mamba implementation from Mambular, try out the original Mamba implementation: https://pypi.org/project/mamba-ssm/
Since a true fix in the package might take some time, there are two workarounds you could try in the meantime:

First, try TabulaRNN(model_type="LSTM", d_conv=16), but from the develop branch. Install the package via:

pip install git+https://github.com/basf/mamba-tabular.git@develop

This should be more memory efficient and similar in speed to the FT-Transformer. I would advise increasing the kernel size of the convolution (d_conv), given your large number of variables.

Second, depending on your resources, you could try to leverage the original Mamba implementation. This can be tricky, since not all systems/GPUs are supported:

pip install mamba-ssm

Then simply import Mamba and switch it with the PyTorch version from Mambular:
```python
from mamba_ssm import Mamba  # you could also try Mamba2
import torch.nn as nn

# in your class, switch out the Mamba block from Mambular with:
self.mamba = nn.ModuleList()
for _ in range(n_layers):  # n_layers from your config
    self.mamba.append(
        Mamba(
            d_model=self.hparams.get("d_model", config.d_model),
            # note: mamba_ssm calls this argument `expand`, not `expand_factor`
            expand=self.hparams.get("expand_factor", config.expand_factor),
            d_conv=self.hparams.get("d_conv", config.d_conv),
        )
    )
```

See https://github.com/state-spaces/mamba for further details on the original implementation.
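In case it helps, here is a minimal usage sketch for the swapped-in layers. It assumes the tabular tokens are already embedded to shape (batch, n_features, d_model); the layer count, dimensions, and variable names are illustrative, and mamba-ssm requires a CUDA device.

```python
# Minimal usage sketch (assumptions: inputs are already embedded tokens of shape
# (batch, n_features, d_model); the residual connections and norms of the
# original Mambular block are omitted; mamba-ssm needs a CUDA device).
import torch
import torch.nn as nn
from mamba_ssm import Mamba

d_model, n_layers = 64, 4                 # illustrative values
layers = nn.ModuleList(
    [Mamba(d_model=d_model, d_conv=4, expand=2) for _ in range(n_layers)]
).cuda()

x = torch.randn(16, 423, d_model, device="cuda")  # (batch, n_features, d_model)
for layer in layers:
    x = layer(x)                          # each Mamba layer maps (B, L, D) -> (B, L, D)
```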
Thank you for your advice! I have tried the original Mamba implementation as suggested, and the training speed has significantly improved. It is now similar in speed to the FT-Transformer and ResNet :).
Hello AnFreTh,
Thank you for your work on this project. I am currently using Mambular to process tabular data, but I am experiencing very slow training speeds. On average, each epoch is taking around 80 minutes to complete.
Here are the details of my setup:
For comparison, when I use ResNet or the FT-Transformer as the tabular encoder with the same setup, the training speed is approximately 25 seconds per epoch, which is significantly faster. Is it expected that Mambular would be much slower than ResNet or the FT-Transformer? Or could this be an issue with my configuration or code?
I would appreciate any insight you could provide. Is there any known issue, or something I can adjust in my configuration to improve the speed?
Please let me know if you need additional information to help diagnose the problem.
Thank you for your time and assistance!