From 0981cee1bb1d96c9b1c8388b6a5f0501c5cb4bd8 Mon Sep 17 00:00:00 2001 From: "xiang song(charlie.song)" Date: Wed, 5 Jun 2024 23:17:39 -0700 Subject: [PATCH] Update gsgnn_node_emb for multi-task learning (#860) *Issue #, if available:* #789 *Description of changes:* By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Xiang Song --- python/graphstorm/config/argument.py | 2 + .../run/gsgnn_emb/gsgnn_node_emb.py | 83 ++++++++++++------- python/graphstorm/run/gsgnn_mt/gsgnn_mt.py | 2 + .../end2end-tests/graphstorm-mt/mgpu_test.sh | 15 ++++ 4 files changed, 71 insertions(+), 31 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 7f5992318a..95be1a935f 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -172,6 +172,8 @@ def __init__(self, cmd_args): # parse multi task learning config and save it into self._multi_tasks if multi_task_config is not None: self._parse_multi_tasks(multi_task_config) + else: + self._multi_tasks = None def set_attributes(self, configuration): """Set class attributes from 2nd level arguments in yaml config""" diff --git a/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py index 98735739f4..3d2e8938d5 100644 --- a/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py +++ b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py @@ -20,13 +20,18 @@ from graphstorm.config import GSConfig from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph from graphstorm.dataloading import GSgnnData -from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, - BUILTIN_TASK_NODE_REGRESSION, - BUILTIN_TASK_EDGE_CLASSIFICATION, - BUILTIN_TASK_EDGE_REGRESSION, - BUILTIN_TASK_LINK_PREDICTION) +from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION, + BUILTIN_TASK_LINK_PREDICTION, + GRAPHSTORM_MODEL_ALL_LAYERS, + GRAPHSTORM_MODEL_EMBED_LAYER, + GRAPHSTORM_MODEL_GNN_LAYER, + GRAPHSTORM_MODEL_DECODER_LAYER) from graphstorm.inference import GSgnnEmbGenInferer from graphstorm.utils import get_lm_ntypes +from graphstorm.model.multitask_gnn import GSgnnMultiTaskSharedEncoderModel def main(config_args): """ main function @@ -44,12 +49,14 @@ def main(config_args): if gs.get_rank() == 0: tracker.log_params(config.__dict__) - assert config.task_type in [BUILTIN_TASK_LINK_PREDICTION, - BUILTIN_TASK_NODE_REGRESSION, - BUILTIN_TASK_NODE_CLASSIFICATION, - BUILTIN_TASK_EDGE_CLASSIFICATION, - BUILTIN_TASK_EDGE_REGRESSION], \ - f"Not supported for task type: {config.task_type}" + if config.multi_tasks is None: + # if not multi-task, check task type + assert config.task_type in [BUILTIN_TASK_LINK_PREDICTION, + BUILTIN_TASK_NODE_REGRESSION, + BUILTIN_TASK_NODE_CLASSIFICATION, + BUILTIN_TASK_EDGE_CLASSIFICATION, + BUILTIN_TASK_EDGE_REGRESSION], \ + f"Not supported for task type: {config.task_type}" input_data = GSgnnData(config.part_config, node_feat_field=config.node_feat_name, @@ -63,14 +70,25 @@ def main(config_args): "restore model path cannot be none for gs_gen_node_embeddings" # load the model - if config.task_type == BUILTIN_TASK_LINK_PREDICTION: - model = gs.create_builtin_lp_gnn_model(input_data.g, config, train_task=False) - elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}: - model = gs.create_builtin_node_gnn_model(input_data.g, config, train_task=False) - elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}: - model = gs.create_builtin_edge_gnn_model(input_data.g, config, train_task=False) + if config.multi_tasks: + # Only support multi-task shared encoder model. + model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm) + gs.gsf.set_encoder(model, input_data.g, config, train_task=False) + assert config.restore_model_layers is not GRAPHSTORM_MODEL_ALL_LAYERS, \ + "When computing node embeddings with GSgnnMultiTaskSharedEncoderModel, " \ + "please set --restore-model-layers to " \ + f"{GRAPHSTORM_MODEL_EMBED_LAYER}, {GRAPHSTORM_MODEL_GNN_LAYER}." \ + f"Please do not include {GRAPHSTORM_MODEL_DECODER_LAYER}, " \ + f"but we get {config.restore_model_layers}" else: - raise TypeError("Not supported for task type: ", config.task_type) + if config.task_type == BUILTIN_TASK_LINK_PREDICTION: + model = gs.create_builtin_lp_gnn_model(input_data.g, config, train_task=False) + elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}: + model = gs.create_builtin_node_gnn_model(input_data.g, config, train_task=False) + elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}: + model = gs.create_builtin_edge_gnn_model(input_data.g, config, train_task=False) + else: + raise TypeError("Not supported for task type: ", config.task_type) model.restore_model(config.restore_model_path, model_layer_to_load=config.restore_model_layers) @@ -78,21 +96,24 @@ def main(config_args): emb_generator = GSgnnEmbGenInferer(model) emb_generator.setup_device(device=get_device()) - task_type = config.task_type - # infer ntypes must be sorted for node embedding saving - if task_type == BUILTIN_TASK_LINK_PREDICTION: + if config.multi_tasks: + # infer_ntypes = None means all node types. infer_ntypes = None - elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}: - # TODO(xiangsx): Support multi-task on multiple node types. - infer_ntypes = [config.target_ntype] - elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}: - infer_ntypes = set() - for etype in config.target_etype: - infer_ntypes.add(etype[0]) - infer_ntypes.add(etype[2]) - infer_ntypes = sorted(list(infer_ntypes)) else: - raise TypeError("Not supported for task type: ", task_type) + task_type = config.task_type + # infer ntypes must be sorted for node embedding saving + if task_type == BUILTIN_TASK_LINK_PREDICTION: + infer_ntypes = None + elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}: + infer_ntypes = [config.target_ntype] + elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}: + infer_ntypes = set() + for etype in config.target_etype: + infer_ntypes.add(etype[0]) + infer_ntypes.add(etype[2]) + infer_ntypes = sorted(list(infer_ntypes)) + else: + raise TypeError("Not supported for task type: ", task_type) emb_generator.infer(input_data, infer_ntypes, save_embed_path=config.save_embed_path, diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 93adf5144d..0d8c6d5c6d 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -416,6 +416,8 @@ def main(config_args): gs.gsf.set_encoder(model, train_data.g, config, train_task=True) tasks = config.multi_tasks + assert tasks is not None, \ + "The multi_task_learning configure block should not be empty." train_dataloaders = [] val_dataloaders = [] test_dataloaders = [] diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh index e968681c23..313598788a 100644 --- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh @@ -378,3 +378,18 @@ then echo "The number of save models $cnt is not equal to the specified topk 3" exit -1 fi + +echo "**************[Multi-task gen embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, load from saved model" +python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_mt/ --num-trainers $NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-embed-path /data/gsgnn_mt/save-emb/ --restore-model-path /data/gsgnn_mt/epoch-2/ --restore-model-layers embed,gnn --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True + +error_and_exit $? + +cnt=$(ls -l /data/gsgnn_mt/save-emb/ | wc -l) +cnt=$[cnt - 1] +if test $cnt != 2 +then + echo "The number of saved embs $cnt is not equal to 2 (for movie and user)." +fi + +# Multi-task will save node embeddings of all the nodes. +python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_mt/emb/ --infer-embout /data/gsgnn_mt/save-emb/ --link-prediction \ No newline at end of file