adapt transformers>=4.46

zhipuch 2024-11-06 11:40:09 +00:00
parent af1d4f2f11
commit a0c568753a
2 changed files with 6 additions and 10 deletions


@@ -56,14 +56,14 @@ class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
 class Seq2SeqTrainer(_Seq2SeqTrainer):
-    # Not Support for apex
-    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
+    # Not supported for apex. transformers>=4.46 requires an additional argument: num_items_in_batch
+    def training_step(self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch=None) -> torch.Tensor:
         model.train()
         inputs = self._prepare_inputs(inputs)
         with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
         if self.args.n_gpu > 1:
             loss = loss.mean()
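
For readers who also need the pre-4.46 behavior, here is a minimal sketch (not part of this commit) of a tolerant training_step override: it accepts num_items_in_batch when transformers>=4.46 passes it, but only forwards the kwarg when it is actually set, so the call also works against older compute_loss signatures. It assumes _Seq2SeqTrainer is the stock transformers.Seq2SeqTrainer, as the aliasing in the context line suggests, and the backward/return tail mirrors the stock Trainer rather than this repository's exact code.

from typing import Any, Optional

import torch
import torch.nn as nn
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer


class Seq2SeqTrainer(_Seq2SeqTrainer):
    # Not supported for apex.
    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Forward the kwarg only when the newer Trainer actually supplied it.
        loss_kwargs = {}
        if num_items_in_batch is not None:
            loss_kwargs["num_items_in_batch"] = num_items_in_batch

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, **loss_kwargs)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # average across data-parallel replicas

        # Assumed tail, mirroring the stock Trainer; not shown in the hunk above.
        self.accelerator.backward(loss)
        return loss.detach() / self.args.gradient_accumulation_steps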
@@ -353,7 +353,6 @@ def load_tokenizer_and_model(
         model = AutoModelForCausalLM.from_pretrained(
             model_dir,
             trust_remote_code=True,
-            empty_init=False,
             use_cache=False,
             torch_dtype=torch.bfloat16  # Must use BFloat 16
         )
@@ -363,7 +362,6 @@ def load_tokenizer_and_model(
         model = AutoModelForCausalLM.from_pretrained(
             model_dir,
             trust_remote_code=True,
-            empty_init=False,
             use_cache=False,
             torch_dtype=torch.bfloat16
         )


@@ -57,14 +57,14 @@ class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
 class Seq2SeqTrainer(_Seq2SeqTrainer):
-    # Not Support for apex
-    def training_step(self, model: nn.Module, inputs: dict[str, Any]) -> torch.Tensor:
+    # Not supported for apex. transformers>=4.46 requires an additional argument: num_items_in_batch
+    def training_step(self, model: nn.Module, inputs: dict[str, Any], num_items_in_batch=None) -> torch.Tensor:
         model.train()
         inputs = self._prepare_inputs(inputs)
         with self.compute_loss_context_manager():
-            loss = self.compute_loss(model, inputs)
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
         if self.args.n_gpu > 1:
             loss = loss.mean()
@@ -399,7 +399,6 @@ def load_tokenizer_and_model(
         model = AutoModelForCausalLM.from_pretrained(
             model_dir,
             trust_remote_code=True,
-            empty_init=False, #if transformers>=4.46 and load glm-4-9b-chat-hf, delete this
             use_cache=False,
             torch_dtype=torch.bfloat16  # Must use BFloat 16
         )
@@ -409,7 +408,6 @@ def load_tokenizer_and_model(
         model = AutoModelForCausalLM.from_pretrained(
             model_dir,
             trust_remote_code=True,
-            empty_init=False, #if transformers>=4.46 and load glm-4-9b-chat-hf, delete this
             use_cache=False,
             torch_dtype=torch.bfloat16
         )
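
The removed empty_init=False arguments are extra kwargs that only the remote-code ChatGLM checkpoints understand; per the inline comment above, they must be dropped once transformers>=4.46 loads glm-4-9b-chat-hf through its built-in GLM implementation. If one script has to serve both setups, a hedged sketch of a simple guard is shown below; the version threshold, the placeholder model_dir, and the guard itself are assumptions for illustration, not part of this commit.

import torch
import transformers
from packaging import version
from transformers import AutoModelForCausalLM

model_dir = "path/to/glm-4-9b-chat-hf"  # placeholder checkpoint path

load_kwargs = dict(
    trust_remote_code=True,
    use_cache=False,
    torch_dtype=torch.bfloat16,  # Must use BFloat 16
)
if version.parse(transformers.__version__) < version.parse("4.46.0"):
    # Only the older remote-code ChatGLM implementation accepts this kwarg.
    load_kwargs["empty_init"] = False

model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)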