# Copyright (c) OpenMMLab. All rights reserved.
import torch
from bitsandbytes.optim import PagedAdamW32bit
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE
#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# PART 1 settings: model path, data source, prompt template, batching /
# optimizer hyper-parameters and evaluation prompts. The later sections of
# this config read these names.
# NOTE(review): this section appears to have lost lines. Later sections
# reference `lr`, `betas`, `weight_decay`, `max_epochs` and `evaluation_freq`,
# none of which is assigned in the visible text -- confirm and restore them.
# NOTE(review): absolute, machine-specific checkpoint path; consider making it
# relative or overridable before sharing this config.
pretrained_model_name_or_path = '/home/yhfu/04Project/01AI/model/Shanghai_AI_Laboratory/internlm-chat-7b'
# Training data; loaded in Part 3 via `load_dataset(path='json', ...)`.
data_path = 'train.jsonl'
# Chat template used by the dataset template map fn and the evaluation hook.
prompt_template = PROMPT_TEMPLATE.internlm_chat
# Forwarded to `process_hf_dataset(pack_to_max_length=...)` in Part 3.
pack_to_max_length = True
batch_size = 3 # per_device
# Forwarded to the optimizer wrapper as `accumulative_counts` (Part 4).
accumulative_counts = 16
dataloader_num_workers = 0
# Optimizer class; instantiated in Part 4 with `lr`, `betas`, `weight_decay`.
optim_type = PagedAdamW32bit
max_norm = 1 # grad clip
# Evaluate the generation performance during the training
# NOTE(review): the line below is a bare tuple expression with no effect.
# Part 5 reads `evaluation_inputs`, so it was presumably the body of
# `evaluation_inputs = [ ... ]` whose surrounding lines are missing -- confirm.
'请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
# PART 2: tokenizer and model configs. Each `type=` names the callable that the
# config builder is presumably expected to invoke with the sibling keys as
# kwargs (mmengine-style config) -- confirm against xtuner's builder.
# NOTE(review): the statements below are keyword fragments whose enclosing
# calls are missing:
# - `tokenizer` is referenced by the DatasetInfoHook config in the runtime
#   section, so the first two lines were presumably inside `tokenizer = dict(`.
# - The remainder looks like the body of `model = dict(type=SupervisedFinetune,
#   llm=dict(...))`.
# - `LoraConfig` is imported but never referenced in the visible text, so a
#   `lora=dict(type=LoraConfig, ...)` entry also appears to be lost.
# Restore these from the original template.
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
type=SupervisedFinetune,
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
# Load base-model weights in float16 (requires `torch` in scope).
torch_dtype=torch.float16,
# bitsandbytes quantization: NF4 4-bit quant type, double quantization, fp16
# compute dtype.
# NOTE(review): no `load_in_4bit=True` appears here; without it the
# bnb_4bit_* options presumably have no effect -- verify against the
# transformers BitsAndBytesConfig docs and the original config.
quantization_config=dict(
type=BitsAndBytesConfig,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')),
#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
# PART 3: training dataset (built by xtuner's `process_hf_dataset`) and its
# dataloader.
# NOTE(review): the section's opening lines are missing:
# - The first fragment below was presumably the body of `train_dataset = dict(`.
# - The `type=template_map_fn_factory, ...),` line closes a parenthesis that is
#   never opened, so a `template_map_fn=dict(` opener was lost.
# - Keys such as `tokenizer=` may also be gone.
# Restore from the original template.
type=process_hf_dataset,
# Load the local JSON/JSON-Lines file at `data_path` as the `train` split.
dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
# Formats each sample with `prompt_template` (Part 1).
type=template_map_fn_factory, template=prompt_template),
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length)
# NOTE(review): this dict sets no `batch_size` or `dataset` key in the visible
# text, although `batch_size` is defined in Part 1 -- presumably
# `batch_size=batch_size` and `dataset=train_dataset` lines were lost; confirm.
train_dataloader = dict(
num_workers=dataloader_num_workers,
# Shuffle sample order via mmengine's DefaultSampler.
sampler=dict(type=DefaultSampler, shuffle=True),
collate_fn=dict(type=default_collate_fn))
#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# PART 4: optimizer wrapper, learning-rate schedule and training-loop config.
# NOTE(review): the openers of this section are missing. `AmpOptimWrapper` and
# `CosineAnnealingLR` are imported but never referenced in the visible text,
# so presumably these lines were lost:
# - `optim_wrapper = dict(type=AmpOptimWrapper, optimizer=dict(`
# - `param_scheduler = dict(type=CosineAnnealingLR, ...`
# `lr`, `betas`, `weight_decay` and `max_epochs` are also never assigned in
# the visible text -- confirm against the original template.
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
# Gradient-norm clipping at `max_norm`; error_if_nonfinite=False presumably
# forwards to torch's clip_grad_norm_ so a NaN/inf norm does not raise.
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
# Gradient-accumulation count from Part 1.
accumulative_counts=accumulative_counts,
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
# Convert an epoch-based schedule into per-iteration LR updates.
convert_to_iter_based=True)
# train, val, test setting
# Epoch-based training loop for `max_epochs` epochs; val_interval=1.
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# PART 5: runtime settings -- custom hooks, default hooks, environment and
# randomness.
# NOTE(review): several openers/closers in this section are missing:
# - `custom_hooks = [` before the DatasetInfoHook entry.
# - A `dict(type=EvaluateChatHook, tokenizer=tokenizer,` opener before
#   `every_n_iters=`. EvaluateChatHook is imported but never referenced in the
#   visible text.
# - `default_hooks = dict(` before `timer=`.
# - `env_cfg = dict(` before `mp_cfg=`.
# `evaluation_freq` is also never assigned in the visible text -- confirm and
# restore.
# Log the dialogue periodically during the training process, optional
dict(type=DatasetInfoHook, tokenizer=tokenizer),
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
prompt_template=prompt_template)
# configure default hooks
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 10 iterations.
logger=dict(type=LoggerHook, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per epoch.
checkpoint=dict(type=CheckpointHook, interval=1),
# set sampler seed in distributed environment.
sampler_seed=dict(type=DistSamplerSeedHook),
# whether to enable cudnn benchmark
# NOTE(review): the `cudnn_benchmark=...` entry this comment describes is
# missing.
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
# load from which checkpoint
# NOTE(review): the `load_from = ...` assignment this comment describes is
# missing.
# whether to resume training from the loaded checkpoint
# NOTE(review): the `resume = ...` assignment this comment describes is
# missing.
# Use a random seed (seed=None) and leave deterministic mode disabled.
randomness = dict(seed=None, deterministic=False)