Update gsgnn_node_emb for multi-task learning (awslabs#860)
*Issue #, if available:*
awslabs#789 

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Jun 6, 2024
1 parent 5e189df commit 0981cee
Showing 4 changed files with 71 additions and 31 deletions.
2 changes: 2 additions & 0 deletions python/graphstorm/config/argument.py
@@ -172,6 +172,8 @@ def __init__(self, cmd_args):
         # parse multi task learning config and save it into self._multi_tasks
         if multi_task_config is not None:
             self._parse_multi_tasks(multi_task_config)
+        else:
+            self._multi_tasks = None

     def set_attributes(self, configuration):
         """Set class attributes from 2nd level arguments in yaml config"""
83 changes: 52 additions & 31 deletions python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
@@ -20,13 +20,18 @@
 from graphstorm.config import GSConfig
 from graphstorm.utils import rt_profiler, sys_tracker, get_device, use_wholegraph
 from graphstorm.dataloading import GSgnnData
-from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
-                               BUILTIN_TASK_NODE_REGRESSION,
-                               BUILTIN_TASK_EDGE_CLASSIFICATION,
-                               BUILTIN_TASK_EDGE_REGRESSION,
-                               BUILTIN_TASK_LINK_PREDICTION)
+from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION,
+                               BUILTIN_TASK_NODE_REGRESSION,
+                               BUILTIN_TASK_EDGE_CLASSIFICATION,
+                               BUILTIN_TASK_EDGE_REGRESSION,
+                               BUILTIN_TASK_LINK_PREDICTION,
+                               GRAPHSTORM_MODEL_ALL_LAYERS,
+                               GRAPHSTORM_MODEL_EMBED_LAYER,
+                               GRAPHSTORM_MODEL_GNN_LAYER,
+                               GRAPHSTORM_MODEL_DECODER_LAYER)
 from graphstorm.inference import GSgnnEmbGenInferer
 from graphstorm.utils import get_lm_ntypes
+from graphstorm.model.multitask_gnn import GSgnnMultiTaskSharedEncoderModel

 def main(config_args):
     """ main function
@@ -44,12 +49,14 @@ def main(config_args):
     if gs.get_rank() == 0:
         tracker.log_params(config.__dict__)

-    assert config.task_type in [BUILTIN_TASK_LINK_PREDICTION,
-                                BUILTIN_TASK_NODE_REGRESSION,
-                                BUILTIN_TASK_NODE_CLASSIFICATION,
-                                BUILTIN_TASK_EDGE_CLASSIFICATION,
-                                BUILTIN_TASK_EDGE_REGRESSION], \
-        f"Not supported for task type: {config.task_type}"
+    if config.multi_tasks is None:
+        # if not multi-task, check task type
+        assert config.task_type in [BUILTIN_TASK_LINK_PREDICTION,
+                                    BUILTIN_TASK_NODE_REGRESSION,
+                                    BUILTIN_TASK_NODE_CLASSIFICATION,
+                                    BUILTIN_TASK_EDGE_CLASSIFICATION,
+                                    BUILTIN_TASK_EDGE_REGRESSION], \
+            f"Not supported for task type: {config.task_type}"

     input_data = GSgnnData(config.part_config,
                            node_feat_field=config.node_feat_name,
@@ -63,36 +70,50 @@ def main(config_args):
"restore model path cannot be none for gs_gen_node_embeddings"

# load the model
if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
model = gs.create_builtin_lp_gnn_model(input_data.g, config, train_task=False)
elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
model = gs.create_builtin_node_gnn_model(input_data.g, config, train_task=False)
elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
model = gs.create_builtin_edge_gnn_model(input_data.g, config, train_task=False)
if config.multi_tasks:
# Only support multi-task shared encoder model.
model = GSgnnMultiTaskSharedEncoderModel(config.alpha_l2norm)
gs.gsf.set_encoder(model, input_data.g, config, train_task=False)
assert config.restore_model_layers is not GRAPHSTORM_MODEL_ALL_LAYERS, \
"When computing node embeddings with GSgnnMultiTaskSharedEncoderModel, " \
"please set --restore-model-layers to " \
f"{GRAPHSTORM_MODEL_EMBED_LAYER}, {GRAPHSTORM_MODEL_GNN_LAYER}." \
f"Please do not include {GRAPHSTORM_MODEL_DECODER_LAYER}, " \
f"but we get {config.restore_model_layers}"
else:
raise TypeError("Not supported for task type: ", config.task_type)
if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
model = gs.create_builtin_lp_gnn_model(input_data.g, config, train_task=False)
elif config.task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
model = gs.create_builtin_node_gnn_model(input_data.g, config, train_task=False)
elif config.task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
model = gs.create_builtin_edge_gnn_model(input_data.g, config, train_task=False)
else:
raise TypeError("Not supported for task type: ", config.task_type)
model.restore_model(config.restore_model_path,
model_layer_to_load=config.restore_model_layers)
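
In practice this means a multi-task checkpoint is restored without its per-task decoders. A minimal sketch of the intended call, assuming the layer constants resolve to "embed" and "gnn" (the CLI form, --restore-model-layers embed,gnn, appears in the test added below):

# Hedged sketch: restore only the shared encoder of a multi-task checkpoint.
# The literal layer names mirror --restore-model-layers embed,gnn; the exact
# constant values are assumptions, not verified against graphstorm.config.
model.restore_model("/data/gsgnn_mt/epoch-2/",            # path used by the e2e test below
                    model_layer_to_load=["embed", "gnn"]) # leave "decoder" layers out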

     # start to infer
     emb_generator = GSgnnEmbGenInferer(model)
     emb_generator.setup_device(device=get_device())

-    task_type = config.task_type
-    # infer ntypes must be sorted for node embedding saving
-    if task_type == BUILTIN_TASK_LINK_PREDICTION:
-        infer_ntypes = None
-    elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
-        infer_ntypes = [config.target_ntype]
-    elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
-        infer_ntypes = set()
-        for etype in config.target_etype:
-            infer_ntypes.add(etype[0])
-            infer_ntypes.add(etype[2])
-        infer_ntypes = sorted(list(infer_ntypes))
-    else:
-        raise TypeError("Not supported for task type: ", task_type)
+    if config.multi_tasks:
+        # infer_ntypes = None means all node types.
+        infer_ntypes = None
+    else:
+        task_type = config.task_type
+        # infer ntypes must be sorted for node embedding saving
+        if task_type == BUILTIN_TASK_LINK_PREDICTION:
+            infer_ntypes = None
+        elif task_type in {BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_NODE_CLASSIFICATION}:
+            # TODO(xiangsx): Support multi-task on multiple node types.
+            infer_ntypes = [config.target_ntype]
+        elif task_type in {BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION}:
+            infer_ntypes = set()
+            for etype in config.target_etype:
+                infer_ntypes.add(etype[0])
+                infer_ntypes.add(etype[2])
+            infer_ntypes = sorted(list(infer_ntypes))
+        else:
+            raise TypeError("Not supported for task type: ", task_type)

     emb_generator.infer(input_data, infer_ntypes,
                         save_embed_path=config.save_embed_path,
2 changes: 2 additions & 0 deletions python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
@@ -416,6 +416,8 @@ def main(config_args):
     gs.gsf.set_encoder(model, train_data.g, config, train_task=True)

     tasks = config.multi_tasks
+    assert tasks is not None, \
+        "The multi_task_learning configuration block should not be empty."
     train_dataloaders = []
     val_dataloaders = []
     test_dataloaders = []
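
These lists are then filled with one dataloader per configured task, roughly as sketched below; the create_task_*_dataloader helper names are hypothetical and not confirmed by this diff:

# Hedged sketch: one dataloader per task in the multi_task_learning block.
for task in tasks:
    train_dataloaders.append(create_task_train_dataloader(task, config, train_data))
    val_dataloaders.append(create_task_val_dataloader(task, config, train_data))
    test_dataloaders.append(create_task_test_dataloader(task, config, train_data))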
15 changes: 15 additions & 0 deletions tests/end2end-tests/graphstorm-mt/mgpu_test.sh
@@ -378,3 +378,18 @@ then
echo "The number of save models $cnt is not equal to the specified topk 3"
exit -1
fi

echo "**************[Multi-task gen embedding] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, load from saved model"
python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_mt/ --num-trainers $NUM_TRAINERS --use-mini-batch-infer false --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_ec_er_lp.yaml --save-embed-path /data/gsgnn_mt/save-emb/ --restore-model-path /data/gsgnn_mt/epoch-2/ --restore-model-layers embed,gnn --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True

error_and_exit $?

cnt=$(ls -l /data/gsgnn_mt/save-emb/ | wc -l)
cnt=$[cnt - 1]
if test $cnt != 2
then
echo "The number of saved embs $cnt is not equal to 2 (for movie and user)."
fi

# Multi-task will save node embeddings of all the nodes.
python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_mt/emb/ --infer-embout /data/gsgnn_mt/save-emb/ --link-prediction
