dtype use bf16
This commit is contained in:
parent
b9d6d3863f
commit
3dec01c6d5
|
@ -11,7 +11,7 @@ Read this in [English](README_en.md)
|
||||||
|
|
||||||
## 项目更新
|
## 项目更新
|
||||||
|
|
||||||
- 🔥 **News**: ``2024/7/1``: 我们更新了 GLM-4V-9B 的多模态微调,您需要更新我们的模型仓库的运行文件和配置文件, 以支持这个功能,更多微调细节 (例如数据集格式,显存要求),请前往[查看](finetune_demo)。
|
- 🔥 **News**: ``2024/7/1``: 我们更新了 GLM-4V-9B 的微调,您需要更新我们的模型仓库的运行文件和配置文件, 以支持这个功能,更多微调细节 (例如数据集格式,显存要求),请前往 [查看](finetune_demo)。
|
||||||
- 🔥 **News**: ``2024/6/28``: 我们与英特尔技术团队合作,改进了 GLM-4-9B-Chat 的 ITREX 和 OpenVINO 部署教程。您可以使用英特尔 CPU/GPU 设备高效部署 GLM-4-9B 开源模型。欢迎访问 [查看](intel_device_demo)。
|
- 🔥 **News**: ``2024/6/28``: 我们与英特尔技术团队合作,改进了 GLM-4-9B-Chat 的 ITREX 和 OpenVINO 部署教程。您可以使用英特尔 CPU/GPU 设备高效部署 GLM-4-9B 开源模型。欢迎访问 [查看](intel_device_demo)。
|
||||||
- 🔥 **News**: ``2024/6/24``: 我们更新了模型仓库的运行文件和配置文件,支持 Flash Attention 2,
|
- 🔥 **News**: ``2024/6/24``: 我们更新了模型仓库的运行文件和配置文件,支持 Flash Attention 2,
|
||||||
请更新模型配置文件并参考 `basic_demo/trans_cli_demo.py` 中的示例代码。
|
请更新模型配置文件并参考 `basic_demo/trans_cli_demo.py` 中的示例代码。
|
||||||
|
|
|
@ -33,7 +33,7 @@ model = AutoModel.from_pretrained(
|
||||||
MODEL_PATH,
|
MODEL_PATH,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
# attn_implementation="flash_attention_2", # Use Flash Attention
|
# attn_implementation="flash_attention_2", # Use Flash Attention
|
||||||
# torch_dtype=torch.bfloat16, # using flash-attn must use bfloat16 or float16,
|
torch_dtype=torch.bfloat16,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
).eval()
|
).eval()
|
||||||
|
|
||||||
|
|
|
@ -271,7 +271,7 @@ def process_batch(
|
||||||
[message],
|
[message],
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
padding=True,
|
padding=True
|
||||||
)
|
)
|
||||||
new_input_ids = new_input_ids_all['input_ids'][0][2:]
|
new_input_ids = new_input_ids_all['input_ids'][0][2:]
|
||||||
new_attention_mask = new_input_ids_all['attention_mask'][0][2:]
|
new_attention_mask = new_input_ids_all['attention_mask'][0][2:]
|
||||||
|
@ -453,6 +453,7 @@ def main(
|
||||||
batched=True,
|
batched=True,
|
||||||
)
|
)
|
||||||
print('train_dataset:', train_dataset)
|
print('train_dataset:', train_dataset)
|
||||||
|
|
||||||
val_dataset = data_manager.get_dataset(
|
val_dataset = data_manager.get_dataset(
|
||||||
Split.VALIDATION,
|
Split.VALIDATION,
|
||||||
functools.partial(
|
functools.partial(
|
||||||
|
@ -463,6 +464,7 @@ def main(
|
||||||
),
|
),
|
||||||
batched=True,
|
batched=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if val_dataset is not None:
|
if val_dataset is not None:
|
||||||
print('val_dataset:', val_dataset)
|
print('val_dataset:', val_dataset)
|
||||||
test_dataset = data_manager.get_dataset(
|
test_dataset = data_manager.get_dataset(
|
||||||
|
|
Loading…
Reference in New Issue