Commit 8a8f6fb (1 parent: 0e0904d)
Showing 8 changed files with 126 additions and 148 deletions.
@@ -1,2 +1,2 @@
-from . import (baai, baichuan, deepseek, gemma, glm, internlm, llama, llava, mamba, microsoft, minicpm, mistral,
-               qwen, telechat, yi)
+from . import (baai, baichuan, deepseek, gemma, glm, internlm, llama, llava, mamba, microsoft, minicpm, mistral, qwen,
+               telechat, yi)
@@ -0,0 +1,92 @@
# NOTE: the imports below are reconstructed for readability and are not shown in
# this hunk. ta_train_dataloader, ta_eval_dataloader, ta_test_dataloader,
# save_ta_fsdp_checkpoint and save_ta_ddp_checkpoint are project-specific
# TorchAcc helpers whose import path is not part of this diff.
import os
import shutil
from typing import Optional

import transformers.trainer as trainer


class TorchAccTrainer:
    # Mixin-style overrides of transformers.Trainer: dataloader construction and
    # TPU checkpoint saving are routed through the TorchAcc helpers.

    def get_train_dataloader(self):
        # Mirror Trainer.get_train_dataloader, but build the loader via the
        # TorchAcc helper instead of the stock implementation.
        if trainer.is_datasets_available():
            import datasets

        if self.train_dataset is None:
            raise ValueError('Trainer: training requires a train_dataset.')

        train_dataset = self.train_dataset
        data_collator = self.data_collator

        if trainer.is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description='training')
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description='training')

        return ta_train_dataloader(train_dataset, data_collator, self._get_train_sampler(), self.args,
                                   self._train_batch_size)

    def get_eval_dataloader(self, eval_dataset=None):
        # Same pattern as above for evaluation.
        if trainer.is_datasets_available():
            import datasets

        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError('Trainer: evaluation requires an eval_dataset.')
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        if trainer.is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description='evaluation')
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description='evaluation')

        return ta_eval_dataloader(eval_dataset, data_collator, self._get_eval_sampler(eval_dataset), self.args)

    def get_test_dataloader(self, test_dataset):
        # Same pattern as above for prediction/test.
        if trainer.is_datasets_available():
            import datasets

        data_collator = self.data_collator

        if trainer.is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description='test')
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description='test')

        return ta_test_dataloader(test_dataset, data_collator, self._get_eval_sampler(test_dataset), self.args)

    def _save_tpu(self, output_dir: Optional[str] = None):
        # Save a checkpoint on TPU/XLA: the master ordinal writes configs and
        # auxiliary files, all ranks go through the TorchAcc save helpers.
        import torch_xla.core.xla_model as xm

        # Compatible with swift and peft
        output_dir = output_dir if output_dir is not None else self.args.output_dir

        if xm.is_master_ordinal(local=False):
            os.makedirs(output_dir, exist_ok=True)
            # configuration.json
            model_dir = getattr(self.model, 'model_dir', None)
            if model_dir is not None:
                src_path = os.path.join(model_dir, 'configuration.json')
                dst_path = os.path.join(output_dir, 'configuration.json')
                if os.path.exists(src_path):
                    shutil.copy(src_path, dst_path)
                else:
                    self._create_configuration_file(self.model, output_dir)
            self._save_sft_args(output_dir)
            # generation_config
            generation_config = getattr(self.args, 'generation_config', None)
            if generation_config is not None:
                generation_config.save_pretrained(output_dir)

        # model
        if self.sft_args.fsdp_num > 1:
            save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
        else:
            save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
        sft_args = getattr(self, 'sft_args', None)

        # additional files
        if xm.is_master_ordinal(local=False):
            if sft_args is not None and sft_args.sft_type == 'full':
                # Parenthesized so 'preprocessor_config.json' is always appended,
                # even when additional_saved_files is set.
                additional_files = (getattr(self.args, 'additional_saved_files', None)
                                    or []) + ['preprocessor_config.json']
                if model_dir is not None:
                    for file in additional_files:
                        src_path = os.path.join(model_dir, file)
                        dst_path = os.path.join(output_dir, file)
                        if os.path.isfile(src_path):
                            shutil.copy(src_path, dst_path)
                        elif os.path.isdir(src_path):
                            shutil.copytree(src_path, dst_path)
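
For orientation, a minimal wiring sketch follows. It is an assumption based on the class shape above, not part of this commit: TorchAccTrainer declares no base class yet uses Trainer attributes such as self.train_dataset and self._remove_unused_columns, so it is presumably combined with transformers.Trainer through multiple inheritance. The subclass name and constructor arguments below are illustrative, and a real setup would also need the swift-specific attributes the mixin touches (sft_args, _save_sft_args, _create_configuration_file).

# Hypothetical composition sketch (not from this commit); TorchAccTrainer is the
# class defined above.
from transformers import Trainer, TrainingArguments


class MyTorchAccTrainer(TorchAccTrainer, Trainer):
    """Trainer whose dataloaders and TPU saving use the TorchAcc overrides above."""


# Illustrative call pattern; model, datasets, collator and tokenizer are assumed
# to be prepared elsewhere.
# trainer = MyTorchAccTrainer(
#     model=model,
#     args=TrainingArguments(output_dir='output'),
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
#     data_collator=data_collator,
#     tokenizer=tokenizer,
# )
# trainer.train()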