This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

RFC: What do you think about TRAX? #1478

Open
lukaszkaiser opened this issue Mar 7, 2019 · 12 comments

@lukaszkaiser
Contributor

We're thinking how to make the next T2T much better. One thing that came up is using JAX and gin config and we've prototyped TRAX:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/trax

If you're interested, please take a look: run one of the examples, try changing something, train on your own data, build your own model, tweak things.

TRAX is very early; it lacks features and has bugs. We know that, and we'll be fixing the small things as we go. But we'd love to hear about higher-level issues that may be easier to address at this stage, before the design stabilizes. Especially if you had trouble doing things in T2T before, let us know whether this looks like it would help!
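For anyone who hasn't tried JAX yet, here is a toy sketch of the functional style it encourages: a pure loss function plus a gradient transform. To be clear, this is not the actual TRAX or JAX API; `grad` below is a hypothetical stand-in that uses finite differences instead of real autodiff.

```python
# Hypothetical sketch (NOT the real TRAX/JAX API) of the functional style:
# a pure loss function, a gradient transform, and a plain SGD loop.
import numpy as np

def loss(w, x, y):
    # Squared-error loss for a scalar-weight linear model y ~ w * x.
    return float(np.mean((w * x - y) ** 2))

def grad(f, eps=1e-6):
    # Stand-in for jax.grad: central finite difference w.r.t. the
    # first argument.  Real JAX differentiates the function by tracing.
    def df(w, *args):
        return (f(w + eps, *args) - f(w - eps, *args)) / (2 * eps)
    return df

x = np.array([1.0, 2.0, 3.0])
y = 2.0 * x          # true weight is 2.0
w = 0.0
dloss = grad(loss)
for _ in range(200):
    w -= 0.1 * dloss(w, x, y)   # plain SGD on the scalar weight
```

The point is that everything is an ordinary function of its inputs, so transforms like `grad` (and in real JAX, `jit` and `vmap`) compose freely.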

@f-lng

f-lng commented Mar 7, 2019

Hello Lukasz,

Current problem #1 - TF1.0
I simply do not like the graph API of TF1, and I think any move away from it is a good thing, so a TF2.0-based T2T and a JAX-based T2T would both do the trick. Since T2T is pretty complex and many questions can only be answered by reading the code, the hard-to-follow graph execution really did not help.

Current problem #2 - Estimators
Being able to just load and use a T2T model inside an application, without a complex tf-serving setup, was also really hard. The Estimator interface is the main problem here, as it reloads the whole model on each predict call. I really hope TF2.0-T2T and JAX-T2T will make it easier to work around Estimators.

TRAX - What I like
Well, the whole code looks really clean and easy to follow. In theory, it should solve all my issues with T2T.

TRAX - What I dislike
I do not like the idea of using yet another framework for T2T. While JAX seems cool and is gaining traction, wouldn't it be better to base the rewrite on TF2.0?
I understand that a seamless transition from the current T2T to a TF2.0-backed version is not possible, so rewriting big parts of the framework might be needed, perhaps even starting from scratch.
But then, if you plan to rewrite much of the code anyway (for the TRAX version), why not pool the resources, start from scratch, and use an eager TF2.0 approach?

I don't think T2T will benefit in the long term from multiple branches that need to be maintained separately, especially if one of them still relies heavily on TF1.0 and graph mode while the other relies on a fairly small framework like JAX.

Thoughts?

@etragas-fathom
Contributor

I've used Autograd in the past for some research and found it really inefficient; I ended up walking away from it to PyTorch, which brought huge speed-ups.

As a more concrete example, at one point the backward pass of Autograd was consuming 20 GB of memory, which ended up being fixed with a PR.

I mention the above to support @f-lng's comment that adding another framework to t2t would make maintaining t2t harder, since bugs stemming from the underlying framework would inevitably arise.

OTOH, I think it's important to note that part of T2T's mission is to make deep learning more accessible. I'm not familiar enough with TF2.0 to know whether migrating to it would help that goal. What are your thoughts on that, @lukaszkaiser / @f-lng?

@f-lng

f-lng commented Mar 9, 2019

@etragas-fathom I do think TF2.0 would serve the mission of making it more accessible: model execution / code flow is simplified by the eager programming style, and model creation is simplified by the Keras API.

I might be biased here, as I only touch TF when I cannot get around it (e.g. for T2T), and I have always found Keras to be a pretty good alternative, especially if you are doing engineering rather than research.

I think a proper TF2.0-based T2T would be perfect, but if TRAX is the way to go, it might be better to simply drop the TF backend altogether and focus on the TRAX version.

@lukaszkaiser
Contributor Author

The problem with TF 2.0, at least for now, is that when you want speed (using @tf.function or the functional Keras mode) you're back in TF 1.0 graph-mode land, with shape bugs, huge stack traces and all; it feels as hard to debug as TF 1.0, or harder.

With JAX the speed problem of Autograd is gone (I'm training a Transformer LM as fast as T2T and a ResNet just a little slower). But other bugs may resurface with more use; we'll have to see, I guess.

Please keep adding comments so we know what to look out for!
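To illustrate the speed point above for readers new to JAX: the gains come largely from its composable function transforms (`grad`, `jit`, `vmap`). Here is a hypothetical pure-Python stand-in for vmap-style auto-batching; real `jax.vmap` vectorizes by tracing rather than looping, so this is only a sketch of the interface, not of how JAX achieves its speed.

```python
# Toy stand-in (NOT real JAX) for vmap-style auto-batching: write the
# per-example function once, then lift it over a leading batch axis.
import numpy as np

def attention_score(q, k):
    # Scaled dot-product score for a single query/key vector pair.
    return float(q @ k) / np.sqrt(q.shape[-1])

def vmap(f):
    # Hypothetical jax.vmap lookalike: map f over the leading axis.
    # Real vmap rewrites the computation to run batched, with no loop.
    def batched(qs, ks):
        return np.array([f(q, k) for q, k in zip(qs, ks)])
    return batched

qs = np.ones((4, 8))
ks = np.ones((4, 8))
scores = vmap(attention_score)(qs, ks)  # one score per batch element
```

The design win is that batching, differentiation, and compilation are orthogonal transforms you stack on a plain function, instead of properties baked into a model class.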

@f-lng

f-lng commented Mar 22, 2019

Well, if you guys say TF 2.0 is not a good fit for T2T (yet?), then I guess TRAX is a good alternative, as the code looks very clean and the rewrite is easy to follow :-)

Btw, is beam-search decoding already implemented and documented? I would love to give it a try.

@JosephRedfern

Perhaps not a hugely insightful comment, but it's a shame that this makes installing recent versions of t2t non-trivial on Windows (see #1507).

@moskomule

I would like TRAX to be an independent repo/package. I would rather not install TensorFlow if possible, but T2T depends on it.

@bzz

bzz commented Oct 28, 2019

As of 6c7c601 it seems TRAX has been moved to its own repo, apparently https://github.com/google/trax, though the commit doesn't make the destination explicit.

@afrozenator
Contributor

afrozenator commented Oct 28, 2019 via email

@AranKomat

@afrozenator @lukaszkaiser

Could you explain the advantage of JAX over PyTorch? Does JAX provide any tooling that makes Transformer and Reformer faster, beyond the fact that JAX probably has better TPU support?

@JonathanSum

JonathanSum commented Aug 11, 2020

@lukaszkaiser
I am not sure whether I am allowed to post this here. I feel what you said is totally correct. To me, Keras and TF are really complicated, and their Graph API and more are a mess.
I hope Trax will be as readable as PyTorch for building more complicated models.

Here is an example of a U-Net in PyTorch, which is readable and beautiful. I am no researcher or engineer, I am just nobody, but I am able to build a GAN. I was also able to build a self-supervised learning model to colorize my favorite Japanese anime pictures without GAN's instability problems. But I can only do it in PyTorch because it is so easy, readable, and beautiful. It is so friendly to open-source people.

https://gist.github.com/Hsankesara/e3b064ff47d538052e059084b8d4df9f#file-unet-py

@lkluo

lkluo commented Feb 3, 2021

@lukaszkaiser Trax takes an extremely long time to import, which makes it very uncomfortable to debug.

10 participants