Skip to content

Commit

Permalink
fixing esmfold_advanced to use pytorch2/cuda12
Browse files Browse the repository at this point in the history
  • Loading branch information
sokrypton committed Dec 28, 2023
1 parent 740d844 commit e14313f
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions beta/ESMFold_advanced.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@
" os.system(\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &\")\n",
"\n",
" # install libs\n",
" os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n",
" print(\"installing libs...\")\n",
" os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif\")\n",
" os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n",
"\n",
" print(\"installing openfold...\")\n",
" # install openfold\n",
" commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n",
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
" os.system(f\"pip install -q git+https://github.com/sokrypton/openfold.git\")\n",
"\n",
" print(\"installing esmfold...\")\n",
" # install esmfold\n",
" os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git@beta\")\n",
"\n",
Expand Down Expand Up @@ -87,7 +89,7 @@
"def parse_output(output):\n",
" pae = (output[\"aligned_confidence_probs\"][0] * np.arange(64)).mean(-1) * 31\n",
" plddt = output[\"plddt\"][0,:,1]\n",
" \n",
"\n",
" bins = np.append(0,np.linspace(2.3125,21.6875,63))\n",
" sm_contacts = softmax(output[\"distogram_logits\"],-1)[0]\n",
" sm_contacts = sm_contacts[...,bins<8].sum(-1)\n",
Expand Down Expand Up @@ -129,7 +131,7 @@
"sequence = \":\".join([sequence] * copies)\n",
"\n",
"#@markdown **sampling options (experimental)**\n",
"#@markdown - Samples are generated via random masking (defined by `masking_rate`) \n",
"#@markdown - Samples are generated via random masking (defined by `masking_rate`)\n",
"#@markdown of input sequence (stochastic_mode=\"LM\") and/or via dropout within structure module (stochastic_mode=\"SM\").\n",
"samples = None #@param [\"None\", \"1\", \"4\", \"8\", \"16\", \"32\", \"64\"] {type:\"raw\"}\n",
"masking_rate = 0.15 #@param {type:\"number\"}\n",
Expand Down Expand Up @@ -180,7 +182,7 @@
" residue_index_offset=512,\n",
" mask_rate=mask_rate,\n",
" return_contacts=get_LM_contacts)\n",
" \n",
"\n",
" pdb_str = model.output_to_pdb(output)[0]\n",
" output = tree_map(lambda x: x.cpu().numpy(), output)\n",
" ptm = output[\"ptm\"][0]\n",
Expand Down Expand Up @@ -226,7 +228,7 @@
" size=(800,480), hbondCutoff=4.0,\n",
" Ls=None,\n",
" animate=False):\n",
" \n",
"\n",
" if chains is None:\n",
" chains = 1 if Ls is None else len(Ls)\n",
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])\n",
Expand All @@ -248,7 +250,7 @@
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" if show_mainchains:\n",
" BB = ['C','O','N','CA']\n",
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
Expand Down Expand Up @@ -368,7 +370,7 @@
"dpi = 100#@param {type:\"integer\"}\n",
"color_by_plddt = True #@param {type:\"boolean\"}\n",
"use_pca = True\n",
"cycle = True \n",
"cycle = True\n",
"\n",
"import matplotlib\n",
"from matplotlib import animation\n",
Expand All @@ -387,7 +389,7 @@
" p = P_ - P_.mean(0,keepdims=True)\n",
" q = Q_ - Q_.mean(0,keepdims=True)\n",
" return ((P - P_.mean(0,keepdims=True)) @ cf.kabsch(p,q)) + Q_.mean(0,keepdims=True)\n",
" \n",
"\n",
" pos = positions[ref] - positions[ref].mean(0,keepdims=True)\n",
" best_2D_view = pos @ cf.kabsch(pos,pos,return_v=True)\n",
"\n",
Expand All @@ -396,7 +398,7 @@
" new_positions.append(align(positions[i],best_2D_view))\n",
" return np.asarray(new_positions)\n",
"\n",
" # align to reference \n",
" # align to reference\n",
" pos = ca_align_to_last(xyz, ref)\n",
"\n",
" fig, (ax1) = plt.subplots(1)\n",
Expand All @@ -415,14 +417,14 @@
" ims=[]\n",
" for l,pos_,plddt_ in zip(labels,pos,plddt):\n",
" if color_by_plddt:\n",
" img = cf.plot_pseudo_3D(pos_, c=plddt_, cmin=50, cmax=90, line_w=line_w, ax=ax1) \n",
" img = cf.plot_pseudo_3D(pos_, c=plddt_, cmin=50, cmax=90, line_w=line_w, ax=ax1)\n",
" elif Ls is None or len(Ls) == 1:\n",
" img = cf.plot_pseudo_3D(pos_, ax=ax1, line_w=line_w)\n",
" else:\n",
" c = np.concatenate([[n]*L for n,L in enumerate(Ls)])\n",
" img = cf.plot_pseudo_3D(pos_, c=c, cmap=cf.pymol_cmap, cmin=0, cmax=39, line_w=line_w, ax=ax1)\n",
" ims.append([cf.add_text(f\"{l}\", ax1),img])\n",
" \n",
"\n",
" ani = animation.ArtistAnimation(fig, ims, blit=True, interval=120)\n",
" plt.close()\n",
" return ani.to_html5_video()\n",
Expand Down Expand Up @@ -456,7 +458,6 @@
"accelerator": "GPU",
"colab": {
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"gpuClass": "standard",
Expand Down

0 comments on commit e14313f

Please sign in to comment.