Commit 0c26f9b: fix train
Jintao-Huang committed Nov 18, 2024 (1 parent: 2811ee1)

Showing 7 changed files with 29 additions and 22 deletions.
17 changes: 7 additions & 10 deletions swift/llm/argument/base_args/base_args.py
@@ -57,8 +57,8 @@ def supported_tuners(self):
     def adapters_can_be_merged(self):
         return TunerArguments.adapters_can_be_merged
 
-    def load_args(self, checkpoint_dir: str) -> None:
-        """Load specific attributes from sft_args.json"""
+    def load_args_from_ckpt(self, checkpoint_dir: str) -> None:
+        """Load specific attributes from args.json"""
         from swift.llm import SftArguments, ExportArguments, InferArguments
         if isinstance(self, SftArguments):
             self.resume_from_checkpoint = to_abspath(self.resume_from_checkpoint, True)
@@ -78,24 +78,21 @@ def load_args(self, checkpoint_dir: str) -> None:
         # read settings
         all_keys = list(f.name for f in fields(self.__class__)) + ['train_type']
         data_keys = list(f.name for f in fields(DataArguments))
+        covered_keys = ['system']
         for key in all_keys:
             if not self.load_dataset_config and key in data_keys:
                 continue
-            value = getattr(self, key)
+            value = getattr(self, key, None)
             old_value = old_args.get(key)  # value in checkpoint
-            if old_value and not value:
+            # TODO: check; system=''
+            if key in covered_keys or old_value and not value:
                 setattr(self, key, old_value)
 
     def save_args(self) -> None:
-        """TODO"""
         from swift.llm import InferArguments
         if isinstance(self, InferArguments):
             return
-        # TODO:check
         self.args_type = self.__class__.__name__
         if is_master():
             fpath = os.path.join(self.output_dir, 'args.json')
-            logger.info(f'The {args.__class__.__name__} will be saved in: {fpath}')
+            logger.info(f'The {self.__class__.__name__} will be saved in: {fpath}')
             with open(fpath, 'w', encoding='utf-8') as f:
-                json.dump(check_json_format(args.__dict__), f, ensure_ascii=False, indent=2)
+                json.dump(check_json_format(self.__dict__), f, ensure_ascii=False, indent=2)
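
The merge rule in the new `load_args_from_ckpt` hinges on Python truthiness and operator precedence: `key in covered_keys or old_value and not value` parses as `key in covered_keys or (old_value and not value)`, so a covered key such as `system` is restored from the checkpoint even when the checkpointed value is the empty string, which the bare `old_value and not value` test would skip. A minimal sketch of the rule, with names taken from the diff and example values that are purely illustrative:

```python
covered_keys = ['system']

def merge_key(key, value, old_value):
    # `old_value and not value` ignores falsy checkpoint values such as
    # system='', so covered keys are force-restored unconditionally.
    if key in covered_keys or old_value and not value:
        return old_value
    return value

assert merge_key('system', 'current prompt', '') == ''  # forced restore of system=''
assert merge_key('model', None, 'qwen/Qwen2-VL-7B-Instruct') == 'qwen/Qwen2-VL-7B-Instruct'
assert merge_key('dataset', ['my_data'], ['old_data']) == ['my_data']  # non-empty value wins
```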
2 changes: 2 additions & 0 deletions swift/llm/argument/infer_args.py
@@ -128,6 +128,8 @@ def _init_stream(self):
             logger.info('Setting args.stream: False')
 
     def __post_init__(self) -> None:
+        if self.ckpt_dir and self.load_args:
+            self.load_args_from_ckpt(self.ckpt_dir)
         BaseArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
         MergeArguments.__post_init__(self)
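
With this hook, pointing `InferArguments` at a checkpoint directory restores the training-time settings from that checkpoint's `args.json` before the rest of `__post_init__` runs (the `load_args` flag is assumed to default to on). A hypothetical invocation, with a placeholder checkpoint path:

```python
from swift.llm import infer_main, InferArguments

# 'output/v0-20241118/checkpoint-10' is an illustrative path; any checkpoint
# directory containing args.json (see the trainers/mixin.py change below) works.
infer_main(InferArguments(ckpt_dir='output/v0-20241118/checkpoint-10', load_dataset_config=True))
```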
9 changes: 5 additions & 4 deletions swift/llm/argument/train_args.py
@@ -172,17 +172,17 @@ class SftArguments(MegatronArguments, TorchAccArguments, TunerArguments, Seq2Seq
     acc_strategy: Literal['token', 'sentence'] = 'token'
 
     def __post_init__(self) -> None:
+        if self.resume_from_checkpoint:
+            self.load_args_from_ckpt(self.resume_from_checkpoint)
+            if self.train_type == 'full':
+                self.model_id_or_path = self.resume_from_checkpoint
         BaseArguments.__post_init__(self)
         Seq2SeqTrainingOverrideArguments.__post_init__(self)
         TunerArguments.__post_init__(self)
         TorchAccArguments.__post_init__(self)
         MegatronArguments.__post_init__(self)
         self._handle_pai_compat()
         self.prepare_deepspeed()
-        if self.resume_from_checkpoint:
-            self.load_args(self.resume_from_checkpoint)
-            if self.train_type == 'full':
-                self.model_id_or_path = self.resume_from_checkpoint
 
         self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()

@@ -200,6 +200,7 @@ def __post_init__(self) -> None:
         else:
             self.init_megatron()
         self._add_version()
+        self.save_args()
 
     def prepare_deepspeed(self):
         """Prepare deepspeed settings"""
7 changes: 3 additions & 4 deletions swift/llm/template/template/qwen.py
@@ -11,7 +11,7 @@
 from ..constant import TemplateType
 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
-from ..utils import Context, Prompt, findall
+from ..utils import Context, Prompt, fetch_one, findall
 from ..vision_utils import load_audio_qwen, load_batch, load_video_qwen2
 from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta

@@ -35,11 +35,10 @@ class Qwen2_5TemplateMeta(QwenTemplateMeta):
 class QwenVLTemplate(Template):
     load_medias = False
 
-    def check_inputs(self, inputs: StdTemplateInputs):
-        if self.infer_backend in {'lmdeploy', 'vllm'}:
+    def _check_inputs(self, inputs: StdTemplateInputs):
+        if self.mode in {'lmdeploy', 'vllm'}:
             return
         images = inputs.images
-        from ..utils import fetch_one
         assert not images or isinstance(fetch_one(images), str), 'QwenVL only supports datasets with images paths!'
 
     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
5 changes: 4 additions & 1 deletion swift/trainers/mixin.py
@@ -175,8 +175,11 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
             self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
         # training_args.bin
         torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
 
-        self._save_converted_model(output_dir)
+        # args.json
+        args_path = os.path.join(os.path.dirname(output_dir), 'args.json')
+        if os.path.exists(args_path):
+            shutil.copy(args_path, os.path.join(output_dir, 'args.json'))
 
         is_adapter = isinstance(self.model, (SwiftModel, PeftModel))
         # tokenizer
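
Together with `save_args()` above, this makes each checkpoint self-contained: `save_args()` writes `args.json` once into the run's output directory, and `_save()` copies it into every `checkpoint-*` subdirectory, which is what `load_args_from_ckpt(ckpt_dir)` reads back. A rough sketch of the resulting layout and copy step (directory names are hypothetical):

```python
import os
import shutil

# Assumed layout after training:
#   output/v0-20241118/args.json           <- written once by save_args()
#   output/v0-20241118/checkpoint-5/       <- _save() copies args.json in here
#   output/v0-20241118/checkpoint-10/      <- ...and here
def copy_args_json(checkpoint_dir: str) -> None:
    # Same logic as the new _save() branch: args.json lives one level up.
    args_path = os.path.join(os.path.dirname(checkpoint_dir), 'args.json')
    if os.path.exists(args_path):
        shutil.copy(args_path, os.path.join(checkpoint_dir, 'args.json'))
```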
3 changes: 3 additions & 0 deletions swift/utils/utils.py
@@ -10,6 +10,7 @@
 from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
 
 import numpy as np
+import torch
 import torch.distributed as dist
 from transformers import HfArgumentParser, enable_full_determinism, set_seed
 from transformers.trainer import TrainingArguments
@@ -24,6 +25,8 @@
 def check_json_format(obj: Any, token_safe: bool = True) -> Any:
     if obj is None or isinstance(obj, (int, float, str, complex)):  # bool is a subclass of int
         return obj
+    if isinstance(obj, torch.dtype):
+        return str(obj)[len('torch.'):]
 
     if isinstance(obj, Sequence):
         res = []
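
`json.dump` raises `TypeError` on a `torch.dtype`, which is presumably what broke `save_args()` for configs carrying `torch_dtype`; the new branch renders the dtype as the short name after the `torch.` prefix. A standalone check of that conversion:

```python
import torch

# str(torch.bfloat16) == 'torch.bfloat16', so slicing off the prefix
# leaves the JSON-friendly short name.
def dtype_to_str(dtype: torch.dtype) -> str:
    return str(dtype)[len('torch.'):]

assert dtype_to_str(torch.bfloat16) == 'bfloat16'
assert dtype_to_str(torch.float32) == 'float32'
```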
8 changes: 5 additions & 3 deletions tests/train/sft.py
@@ -2,7 +2,8 @@
     'per_device_train_batch_size': 2,
     'save_steps': 5,
     'gradient_accumulation_steps': 4,
-    'logging_first_step': True
+    'logging_first_step': True,
+    'metric_for_best_model': 'loss'
 }
 
 
@@ -16,13 +17,14 @@ def test_llm():
 
 
 def test_mllm():
-    from swift.llm import sft_main, SftArguments
+    from swift.llm import sft_main, SftArguments, infer_main, InferArguments
     result = sft_main(
         SftArguments(
             model='qwen/Qwen2-VL-7B-Instruct',
             dataset=['modelscope/coco_2014_caption:validation#20', 'AI-ModelScope/alpaca-gpt4-data-en#20'],
             **kwargs))
     print()
+    last_model_checkpoint = result['last_model_checkpoint']
+    infer_main(InferArguments(ckpt_dir=last_model_checkpoint, load_dataset_config=True, merge_lora=True))
 
 
 if __name__ == '__main__':