From 34921debbc05f4722af8c9b5a544582d12f84e8c Mon Sep 17 00:00:00 2001 From: YoshitakaMo Date: Sun, 2 Jul 2023 02:09:42 +0900 Subject: [PATCH] skip templatesearch when custom_temp_path provided --- colabfold/batch.py | 129 +++++++++++++++++++++++++-------------------- 1 file changed, 71 insertions(+), 58 deletions(-) diff --git a/colabfold/batch.py b/colabfold/batch.py index 0648e143..6e31403b 100644 --- a/colabfold/batch.py +++ b/colabfold/batch.py @@ -299,11 +299,11 @@ def relax_me(pdb_filename=None, pdb_lines=None, pdb_obj=None, use_gpu=False): from alphafold.common import residue_constants from alphafold.relax import relax - if pdb_obj is None: + if pdb_obj is None: if pdb_lines is None: pdb_lines = Path(pdb_filename).read_text() pdb_obj = protein.from_pdb_string(pdb_lines) - + amber_relaxer = relax.AmberRelaxation( max_iterations=0, tolerance=2.39, @@ -311,7 +311,7 @@ def relax_me(pdb_filename=None, pdb_lines=None, pdb_obj=None, use_gpu=False): exclude_residues=[], max_outer_iterations=3, use_gpu=use_gpu) - + relaxed_pdb_lines, _, _ = amber_relaxer.process(prot=pdb_obj) return relaxed_pdb_lines @@ -321,7 +321,7 @@ def __init__(self, prefix: str, result_dir: Path): self.result_dir = result_dir self.tag = None self.files = {} - + def get(self, x: str, ext:str) -> Path: if self.tag not in self.files: self.files[self.tag] = [] @@ -366,13 +366,13 @@ def predict_structure( # iterate through random seeds for seed_num, seed in enumerate(range(random_seed, random_seed+num_seeds)): - + # iterate through models for model_num, (model_name, model_runner, params) in enumerate(model_runner_and_params): - + # swap params to avoid recompiling model_runner.params = params - + ######################### # process input features ######################### @@ -383,24 +383,24 @@ def predict_structure( input_features["asym_id"] = input_features["asym_id"] - input_features["asym_id"][...,0] else: if model_num == 0: - input_features = model_runner.process_features(feature_dict, random_seed=seed) + input_features = model_runner.process_features(feature_dict, random_seed=seed) r = input_features["aatype"].shape[0] input_features["asym_id"] = np.tile(feature_dict["asym_id"],r).reshape(r,-1) if seq_len < pad_len: - input_features = pad_input(input_features, model_runner, + input_features = pad_input(input_features, model_runner, model_name, pad_len, use_templates) logger.info(f"Padding length to {pad_len}") - + tag = f"{model_type}_{model_name}_seed_{seed:03d}" model_names.append(tag) files.set_tag(tag) - + ######################## # predict ######################## start = time.time() - + # monitor intermediate results def callback(result, recycles): if recycles == 0: result.pop("tol",None) @@ -419,12 +419,12 @@ def callback(result, recycles): result=result, b_factors=b_factors, remove_leading_feature_dimension=("multimer" not in model_type)) files.get("unrelaxed",f"r{recycles}.pdb").write_text(protein.to_pdb(unrelaxed_protein)) - + if save_all: with files.get("all",f"r{recycles}.pickle").open("wb") as handle: pickle.dump(result, handle) del unrelaxed_protein - + return_representations = save_all or save_single_representations or save_pair_representations # predict @@ -439,9 +439,9 @@ def callback(result, recycles): ######################## # parse results ######################## - + # summary metrics - mean_scores.append(result["ranking_confidence"]) + mean_scores.append(result["ranking_confidence"]) if recycles == 0: result.pop("tol",None) if not is_complex: result.pop("iptm",None) print_line = "" @@ -469,7 +469,7 @@ def callback(result, recycles): ######################### # save results - ######################### + ######################### # save pdb protein_lines = protein.to_pdb(unrelaxed_protein) @@ -498,12 +498,12 @@ def callback(result, recycles): del pae del plddt json.dump(scores, handle) - + del result, unrelaxed_protein # early stop criteria fulfilled if mean_scores[-1] > stop_at_score: break - + # early stop criteria fulfilled if mean_scores[-1] > stop_at_score: break @@ -514,7 +514,7 @@ def callback(result, recycles): ################################################### # rerank models based on predicted confidence ################################################### - + rank, metric = [],[] result_files = [] logger.info(f"reranking models by '{rank_by}' metric") @@ -527,7 +527,7 @@ def callback(result, recycles): if n < num_relax: start = time.time() pdb_lines = relax_me(pdb_lines=unrelaxed_pdb_lines[key], use_gpu=use_gpu_relax) - files.get("relaxed","pdb").write_text(pdb_lines) + files.get("relaxed","pdb").write_text(pdb_lines) logger.info(f"Relaxation took {(time.time() - start):.1f}s") # rename files to include rank @@ -538,7 +538,7 @@ def callback(result, recycles): new_file = result_dir.joinpath(f"{prefix}_{x}_{new_tag}.{ext}") file.rename(new_file) result_files.append(new_file) - + return {"rank":rank, "metric":metric, "result_files":result_files} @@ -649,10 +649,10 @@ def get_queries( # sort by seq. len if sort_queries_by == "length": queries.sort(key=lambda t: len("".join(t[1]))) - + elif sort_queries_by == "random": random.shuffle(queries) - + is_complex = False for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries): if isinstance(query_sequence, list): @@ -719,6 +719,7 @@ def pad_sequences( def get_msa_and_templates( jobname: str, query_sequences: Union[str, List[str]], + a3m_lines: Optional[List[str]], result_dir: Path, msa_mode: str, use_templates: bool, @@ -749,17 +750,29 @@ def get_msa_and_templates( # get template features template_features = [] if use_templates: - a3m_lines_mmseqs2, template_paths = run_mmseqs2( - query_seqs_unique, - str(result_dir.joinpath(jobname)), - use_env, - use_templates=True, - host_url=host_url, - ) + # Skip template search when custom_template_path is provided if custom_template_path is not None: + if a3m_lines is None: + a3m_lines_mmseqs2 = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_templates=False, + host_url=host_url, + ) + else: + a3m_lines_mmseqs2 = a3m_lines template_paths = {} for index in range(0, len(query_seqs_unique)): template_paths[index] = custom_template_path + else: + a3m_lines_mmseqs2, template_paths = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_templates=True, + host_url=host_url, + ) if template_paths is None: logger.info("No template detected") for index in range(0, len(query_seqs_unique)): @@ -966,7 +979,7 @@ def generate_input_feature( # bugfix a3m_lines = f">0\n{full_sequence}\n" - a3m_lines += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa) + a3m_lines += pair_msa(query_seqs_unique, query_seqs_cardinality, paired_msa, unpaired_msa) input_feature = build_monomer_feature(full_sequence, a3m_lines, mk_mock_template(full_sequence)) input_feature["residue_index"] = np.concatenate([np.arange(L) for L in Ls]) @@ -987,7 +1000,7 @@ def generate_input_feature( chain_cnt = 0 # for each unique sequence for sequence_index, sequence in enumerate(query_seqs_unique): - + # get unpaired msa if unpaired_msa is None: input_msa = f">{101 + sequence_index}\n{sequence}" @@ -1243,7 +1256,7 @@ def run( if max_msa is not None: max_seq, max_extra_seq = [int(x) for x in max_msa.split(":")] - if kwargs.pop("use_amber", False) and num_relax == 0: + if kwargs.pop("use_amber", False) and num_relax == 0: num_relax = num_models * num_seeds if len(kwargs) > 0: @@ -1263,14 +1276,14 @@ def run( L = len("".join(query_sequence)) if L > max_len: max_len = L if N > max_num: max_num = N - + # get max sequences # 512 5120 = alphafold_ptm (models 1,3,4) # 512 1024 = alphafold_ptm (models 2,5) # 508 2048 = alphafold-multimer_v3 (models 1,2,3) # 508 1152 = alphafold-multimer_v3 (models 4,5) # 252 1152 = alphafold-multimer_v[1,2] - + set_if = lambda x,y: y if x is None else x if model_type in ["alphafold2_multimer_v1","alphafold2_multimer_v2"]: (max_seq, max_extra_seq) = (set_if(max_seq,252), set_if(max_extra_seq,1152)) @@ -1278,7 +1291,7 @@ def run( (max_seq, max_extra_seq) = (set_if(max_seq,508), set_if(max_extra_seq,2048)) else: (max_seq, max_extra_seq) = (set_if(max_seq,512), set_if(max_extra_seq,5120)) - + if msa_mode == "single_sequence": num_seqs = 1 if is_complex and "multimer" not in model_type: num_seqs += max_num @@ -1337,7 +1350,7 @@ def run( first_job = True for job_number, (raw_jobname, query_sequence, a3m_lines) in enumerate(queries): jobname = safe_filename(raw_jobname) - + ####################################### # check if job has already finished ####################################### @@ -1359,27 +1372,27 @@ def run( # generate MSA (a3m_lines) and templates ########################################### try: - if a3m_lines is None: + if a3m_lines is None: (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \ - = get_msa_and_templates(jobname, query_sequence, result_dir, msa_mode, use_templates, + = get_msa_and_templates(jobname, query_sequence, a3m_lines, result_dir, msa_mode, use_templates, custom_template_path, pair_mode, pairing_strategy, host_url) - - elif a3m_lines is not None: + + elif a3m_lines is not None: (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features) \ = unserialize_msa(a3m_lines, query_sequence) - if use_templates: + if use_templates: (_, _, _, _, template_features) \ - = get_msa_and_templates(jobname, query_seqs_unique, result_dir, 'single_sequence', use_templates, + = get_msa_and_templates(jobname, query_seqs_unique, a3m_lines, result_dir, 'single_sequence', use_templates, custom_template_path, pair_mode, pairing_strategy, host_url) - + # save a3m msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) result_dir.joinpath(f"{jobname}.a3m").write_text(msa) - + except Exception as e: logger.exception(f"Could not get MSA/templates for {jobname}: {e}") continue - + ####################### # generate features ####################### @@ -1387,23 +1400,23 @@ def run( (feature_dict, domain_names) \ = generate_input_feature(query_seqs_unique, query_seqs_cardinality, unpaired_msa, paired_msa, template_features, is_complex, model_type, max_seq=max_seq) - + # to allow display of MSA info during colab/chimera run (thanks tomgoddard) if feature_dict_callback is not None: feature_dict_callback(feature_dict) - + except Exception as e: logger.exception(f"Could not generate input features {jobname}: {e}") continue - + ###################### # predict structures ###################### try: # get list of lengths - query_sequence_len_array = sum([[len(x)] * y + query_sequence_len_array = sum([[len(x)] * y for x,y in zip(query_seqs_unique, query_seqs_cardinality)],[]) - + # decide how much to pad (to avoid recompiling) if seq_len > pad_len: if isinstance(recompile_padding, float): @@ -1411,7 +1424,7 @@ def run( else: pad_len = seq_len + recompile_padding pad_len = min(pad_len, max_len) - + # prep model and params if first_job: # if one job input adjust max settings @@ -1423,7 +1436,7 @@ def run( num_seqs = int(len(feature_dict["msa"])) if use_templates: num_seqs += 4 - + # adjust max settings max_seq = min(num_seqs, max_seq) max_extra_seq = max(min(num_seqs - max_seq, max_extra_seq), 1) @@ -1498,7 +1511,7 @@ def run( scores_file = result_dir.joinpath(f"{jobname}_scores_{r}.json") with scores_file.open("r") as handle: scores.append(json.load(handle)) - + # write alphafold-db format (pAE) if "pae" in scores[0]: af_pae_file = result_dir.joinpath(f"{jobname}_predicted_aligned_error_v1.json") @@ -1535,7 +1548,7 @@ def run( with zipfile.ZipFile(result_zip, "w") as result_zip: for file in result_files: result_zip.write(file, arcname=file.name) - + # Delete only after the zip was successful, and also not the bibtex and config because we need those again for file in result_files[:-2]: file.unlink() @@ -1737,7 +1750,7 @@ def main(): ) args = parser.parse_args() - + # disable unified memory if args.disable_unified_memory: for k in ENV.keys(): @@ -1756,7 +1769,7 @@ def main(): queries, is_complex = get_queries(args.input, args.sort_queries_by) model_type = set_model_type(is_complex, args.model_type) - + download_alphafold_params(model_type, data_dir) if args.msa_mode != "single_sequence" and not args.templates: