update train
Jintao-Huang committed Nov 16, 2024
1 parent 0e0904d commit 8a8f6fb
Showing 8 changed files with 126 additions and 148 deletions.
4 changes: 2 additions & 2 deletions swift/llm/model/model/__init__.py
@@ -1,2 +1,2 @@
from . import (baai, baichuan, deepseek, gemma, glm, internlm, llama, llava, mamba, microsoft, minicpm, mistral,
qwen, telechat, yi)
from . import (baai, baichuan, deepseek, gemma, glm, internlm, llama, llava, mamba, microsoft, minicpm, mistral, qwen,
telechat, yi)
11 changes: 5 additions & 6 deletions swift/llm/model/model_arch.py
@@ -380,12 +380,11 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
vision_tower='model.vision',
))

register_model_arch(
MultiModelKeys(
ModelArch.florence,
language_model='language_model',
vision_tower='vision_tower',
))
register_model_arch(MultiModelKeys(
ModelArch.florence,
language_model='language_model',
vision_tower='vision_tower',
))

register_model_arch(
MultiModelKeys(
2 changes: 0 additions & 2 deletions swift/llm/template/base.py
@@ -643,8 +643,6 @@ def pre_forward_hook(self,

def set_infer_backend(self, infer_backend: Literal['vllm', 'lmdeploy', 'pt']) -> None:
self.infer_backend = infer_backend
if infer_backend in {'vllm', 'lmdeploy'}:
self.remove_post_encode_hook()

def register_post_encode_hook(self, models: List[nn.Module]) -> None:
"""This function is important for multi-modal training, as it registers the post_encode method
6 changes: 3 additions & 3 deletions swift/llm/train/sft.py
@@ -153,7 +153,7 @@ def run(self):
preprocess_logits_for_metrics = preprocess_logits_for_acc

trainer_cls = TrainerFactory.get_trainer_cls(args)
trainer_cls(
trainer = trainer_cls(
model=self.model,
args=self.args.training_args,
data_collator=data_collator,
@@ -163,10 +163,10 @@ def run(self):
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
callbacks=self.callbacks,
optimizers=optimizers,
sequence_parallel_size=args.sequence_parallel_size,
tokenizer=self.tokenizer,
check_model=args.check_model,
)
template.register_post_encode_hook([self.model])
trainer.train(args.training_args.resume_from_checkpoint)

def _prepare_optimizers(self, train_dataset):
args = self.args
5 changes: 4 additions & 1 deletion swift/trainers/arguments.py
@@ -20,8 +20,11 @@ class SwiftArgumentsMixin:
acc_strategy: str = field(default='token', metadata={'choices': ['token', 'sentence']})
loss_name: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
additional_saved_files: Optional[List[str]] = None
# torchacc
sequence_parallel_size: int = 1
check_model: bool = True
train_sampler_random: bool = True

# torchacc
metric_warmup_step: Optional[float] = 0
train_dataset_sample: Optional[int] = -1

153 changes: 20 additions & 133 deletions swift/trainers/mixin.py
@@ -51,21 +51,20 @@

class SwiftMixin:

def __init__(self,
model: Union[PreTrainedModel, Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[HfDataset] = None,
eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
*,
check_model: bool = True,
sequence_parallel_size: int = 1) -> None:
def __init__(
self,
model: Union[PreTrainedModel, Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[HfDataset] = None,
eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor],
torch.Tensor]] = None) -> None:
# if check_model and hasattr(model, 'model_dir'):
# check_local_model_is_latest(
# model.model_dir,
@@ -74,9 +73,9 @@ def __init__(self,
# Invoke.THIRD_PARTY: kwargs.pop(Invoke.THIRD_PARTY, Invoke.SWIFT),
# })

if sequence_parallel_size > 1:
if args.sequence_parallel_size > 1:
from swift.trainers.xtuner import init_sequence_parallel_xtuner
init_sequence_parallel_xtuner(self.sequence_parallel_size)
init_sequence_parallel_xtuner(args.sequence_parallel_size)

super().__init__(
model=model,
@@ -163,53 +162,6 @@ def _load_optimizer_and_scheduler(self, checkpoint):
self.optimizer, self.lr_scheduler = ta_load_optimizer_and_scheduler(self.optimizer, self.lr_scheduler,
checkpoint, self.args.device)

def _save_tpu(self, output_dir: Optional[str] = None):
if not use_torchacc():
return super()._save_tpu(output_dir)

import torch_xla.core.xla_model as xm

# Compatible with swift and peft
output_dir = output_dir if output_dir is not None else self.args.output_dir

if xm.is_master_ordinal(local=False):
os.makedirs(output_dir, exist_ok=True)
# configuration.json
model_dir = getattr(self.model, 'model_dir', None)
if model_dir is not None:
src_path = os.path.join(model_dir, 'configuration.json')
dst_path = os.path.join(output_dir, 'configuration.json')
if os.path.exists(src_path):
shutil.copy(src_path, dst_path)
else:
self._create_configuration_file(self.model, output_dir)
self._save_sft_args(output_dir)
# generation_config
generation_config = getattr(self.args, 'generation_config', None)
if generation_config is not None:
generation_config.save_pretrained(output_dir)

# model
if self.sft_args.fsdp_num > 1:
save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
else:
save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
sft_args = getattr(self, 'sft_args', None)

# additional files
if xm.is_master_ordinal(local=False):
if sft_args is not None and sft_args.sft_type == 'full':
additional_files = getattr(self.args, 'additional_saved_files',
None) or [] + ['preprocessor_config.json']
if model_dir is not None:
for file in additional_files:
src_path = os.path.join(model_dir, file)
dst_path = os.path.join(output_dir, file)
if os.path.isfile(src_path):
shutil.copy(src_path, dst_path)
elif os.path.isdir(src_path):
shutil.copytree(src_path, dst_path)

def _save(self, output_dir: Optional[str] = None, state_dict=None):
"""Compatible with swift and peft"""
# If we are executing this function, we are the process zero, so we don't check for that.
@@ -357,8 +309,7 @@ def _save_only_model(self, model, trial, metrics=None):
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
train_sampler_random = self.args.train_sampler_random
if train_sampler_random:
if self.args.train_sampler_random:
return super()._get_train_sampler()
else:
return self._get_eval_sampler(self.train_dataset)
@@ -406,19 +357,8 @@ def _sorted_checkpoints(self,
return checkpoints_sorted

def train(self, resume_from_checkpoint: Optional[Union[str, bool]] = None, *args, **kwargs) -> torch.Tensor:
sft_args = getattr(self, 'sft_args', None)
self._resume_only_model = getattr(sft_args, 'resume_only_model', False)
if self._resume_only_model:
# Control the behavior of "resume_from_checkpoint" by swift.
self._resume_from_checkpoint = resume_from_checkpoint
resume_from_checkpoint = None
if self._resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_fsdp_enabled:
self._load_from_checkpoint(self._resume_from_checkpoint)

self._save_initial_model(self.args.output_dir)
res = super().train(resume_from_checkpoint, *args, **kwargs)
self._resume_from_checkpoint = None
return res
return super().train(resume_from_checkpoint, *args, **kwargs)

def _load_best_model(self):
# Compatible with transformers>=4.35 (deepspeed)
@@ -500,8 +440,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
else:
self.create_optimizer()
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)

def create_optimizer(self):
opt_model = self.model
@@ -541,64 +480,12 @@ def create_optimizer(self):
return self.optimizer

def get_train_dataloader(self):
if self.sequence_parallel_size > 1:
if self.args.sequence_parallel_size > 1:
from swift.trainers.xtuner import get_xtuner_train_dataloader
return get_xtuner_train_dataloader(self)
elif use_torchacc():
if trainer.is_datasets_available():
import datasets

if self.train_dataset is None:
raise ValueError('Trainer: training requires a train_dataset.')

train_dataset = self.train_dataset
data_collator = self.data_collator

if trainer.is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description='training')
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description='training')

return ta_train_dataloader(train_dataset, data_collator, self._get_train_sampler(), self.args,
self._train_batch_size)
else:
return super().get_train_dataloader()

def get_eval_dataloader(self, eval_dataset=None):
if not use_torchacc():
return super().get_eval_dataloader(eval_dataset)
else:
if trainer.is_datasets_available():
import datasets

if eval_dataset is None and self.eval_dataset is None:
raise ValueError('Trainer: evaluation requires an eval_dataset.')
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
data_collator = self.data_collator

if trainer.is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description='evaluation')
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description='evaluation')

return ta_eval_dataloader(eval_dataset, data_collator, self._get_eval_sampler(eval_dataset), self.args)

def get_test_dataloader(self, test_dataset):
if not use_torchacc():
return super().get_test_dataloader(test_dataset)
else:
if trainer.is_datasets_available():
import datasets

data_collator = self.data_collator

if trainer.is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description='test')
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description='test')

return ta_test_dataloader(test_dataset, data_collator, self._get_eval_sampler(test_dataset), self.args)


class ModelWrapper(nn.Module):
# compat zero3 & rlhf
92 changes: 92 additions & 0 deletions swift/trainers/torchacc_trainer.py
@@ -0,0 +1,92 @@
class TorchAccTrainer:

def get_train_dataloader(self):
if trainer.is_datasets_available():
import datasets

if self.train_dataset is None:
raise ValueError('Trainer: training requires a train_dataset.')

train_dataset = self.train_dataset
data_collator = self.data_collator

if trainer.is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description='training')
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description='training')

return ta_train_dataloader(train_dataset, data_collator, self._get_train_sampler(), self.args,
self._train_batch_size)

def get_eval_dataloader(self, eval_dataset=None):
if trainer.is_datasets_available():
import datasets

if eval_dataset is None and self.eval_dataset is None:
raise ValueError('Trainer: evaluation requires an eval_dataset.')
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
data_collator = self.data_collator

if trainer.is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description='evaluation')
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description='evaluation')

return ta_eval_dataloader(eval_dataset, data_collator, self._get_eval_sampler(eval_dataset), self.args)

def get_test_dataloader(self, test_dataset):
if trainer.is_datasets_available():
import datasets

data_collator = self.data_collator

if trainer.is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description='test')
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description='test')

return ta_test_dataloader(test_dataset, data_collator, self._get_eval_sampler(test_dataset), self.args)

def _save_tpu(self, output_dir: Optional[str] = None):
import torch_xla.core.xla_model as xm

# Compatible with swift and peft
output_dir = output_dir if output_dir is not None else self.args.output_dir

if xm.is_master_ordinal(local=False):
os.makedirs(output_dir, exist_ok=True)
# configuration.json
model_dir = getattr(self.model, 'model_dir', None)
if model_dir is not None:
src_path = os.path.join(model_dir, 'configuration.json')
dst_path = os.path.join(output_dir, 'configuration.json')
if os.path.exists(src_path):
shutil.copy(src_path, dst_path)
else:
self._create_configuration_file(self.model, output_dir)
self._save_sft_args(output_dir)
# generation_config
generation_config = getattr(self.args, 'generation_config', None)
if generation_config is not None:
generation_config.save_pretrained(output_dir)

# model
if self.sft_args.fsdp_num > 1:
save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
else:
save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
sft_args = getattr(self, 'sft_args', None)

# additional files
if xm.is_master_ordinal(local=False):
if sft_args is not None and sft_args.sft_type == 'full':
additional_files = getattr(self.args, 'additional_saved_files',
None) or [] + ['preprocessor_config.json']
if model_dir is not None:
for file in additional_files:
src_path = os.path.join(model_dir, file)
dst_path = os.path.join(output_dir, file)
if os.path.isfile(src_path):
shutil.copy(src_path, dst_path)
elif os.path.isdir(src_path):
shutil.copytree(src_path, dst_path)
1 change: 0 additions & 1 deletion swift/trainers/trainers.py
@@ -97,7 +97,6 @@ def prediction_step(
generation_inputs = generate_inputs[self.model.main_input_name]

generated_tokens = generated_tokens[:, generation_inputs.shape[1]:]
gen_len = len(generated_tokens[0])

# in case the batch is shorter than max length, the output should be padded
if gen_kwargs.get('max_length') is not None and generated_tokens.shape[-1] < gen_kwargs['max_length']:
