first commit

This commit is contained in:
xxl 2025-01-06 15:20:48 +08:00
parent 7b1e2a820b
commit 8b8b0500e8
17 changed files with 414405 additions and 2 deletions

100
README.md

@@ -1,3 +1,99 @@
# ultravox-v0_3_a14192466442186752847586
---
language:
- en
license: mit
library_name: transformers
datasets:
- fixie-ai/librispeech_asr
- fixie-ai/common_voice_17_0
pipeline_tag: audio-text-to-text
---
ultravox-v0_3
# Model Card for Ultravox
Ultravox is a multimodal Speech LLM built around a pretrained [Llama3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) and [Whisper-small](https://huggingface.co/openai/whisper-small) backbone.\
See https://ultravox.ai for the GitHub repo and more information.
## Model Details
### Model Description
Ultravox is a multimodal model that can consume both speech and text as input (e.g., a text system prompt and voice user message).
The input to the model is given as a text prompt with a special `<|audio|>` pseudo-token, and the model processor will replace this magic token with embeddings derived from the input audio.
Using the merged embeddings as input, the model will then generate output text as usual.
In a future revision of Ultravox, we plan to expand the token vocabulary to support generation of semantic and acoustic audio tokens, which can then be fed to a vocoder to produce voice output.
No preference tuning has been applied to this revision of the model.
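For intuition, the merge step can be pictured as splicing the projected audio embeddings into the text embedding sequence at the position of the `<|audio|>` placeholder. The snippet below is a minimal, self-contained sketch in plain PyTorch with made-up shapes; it is not the repository's implementation (see `UltravoxModel.forward` in `ultravox_model.py` for the real logic).
```python
import torch

# Illustrative shapes only (assumed, not taken from the model).
hidden = 4096                              # LM embedding width
text_embeds = torch.randn(1, 24, hidden)   # embeddings of the tokenized prompt
audio_embeds = torch.randn(1, 6, hidden)   # audio features after the multimodal projector
audio_start, audio_len = 10, 6             # span the <|audio|> placeholder expands into

# Splice the audio embeddings into the text embeddings at the placeholder position.
merged = text_embeds.clone()
merged[0, audio_start : audio_start + audio_len] = audio_embeds[0, :audio_len]

# `merged` is then fed to the language model as `inputs_embeds`, and generation proceeds as usual.
```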
- **Developed by:** Fixie.ai
- **License:** MIT
### Model Sources
- **Repository:** https://ultravox.ai
- **Demo:** See repo
## Usage
Think of the model as an LLM that can also hear and understand speech. As such, it can be used as a voice agent, for speech-to-speech translation, for analysis of spoken audio, and so on.
To use the model, try the following:
```python
# pip install transformers peft librosa
import transformers
import numpy as np
import librosa
pipe = transformers.pipeline(model='fixie-ai/ultravox-v0_3', trust_remote_code=True)
path = "<path-to-input-audio>" # TODO: pass the audio here
audio, sr = librosa.load(path, sr=16000)
turns = [
    {
        "role": "system",
        "content": "You are a friendly and helpful character. You love to answer questions for people."
    },
]
pipe({'audio': audio, 'turns': turns, 'sampling_rate': sr}, max_new_tokens=30)
```
## Training Details
The model uses a pre-trained [Llama3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) backbone as well as the encoder part of [Whisper-small](https://huggingface.co/openai/whisper-small).
Only the multi-modal adapter is trained, while the Whisper encoder and Llama are kept frozen.
We use a knowledge-distillation loss in which Ultravox tries to match the logits of the text-based Llama backbone.
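As a point of reference, here is a minimal sketch of this kind of distillation objective: a temperature-scaled KL divergence between the student's and the frozen teacher's next-token distributions, restricted to supervised positions. Names and defaults below are illustrative; the actual implementation is `_compute_kl_loss` in `ultravox_model.py`.
```python
import torch
import torch.nn.functional as F

def kl_distillation_loss(student_logits: torch.Tensor,
                         teacher_logits: torch.Tensor,
                         labels: torch.Tensor,
                         temperature: float = 2.0) -> torch.Tensor:
    # Compare distributions only at supervised positions (labels != -100).
    mask = labels != -100
    student_log_probs = F.log_softmax(student_logits[mask] / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits[mask] / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
```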
### Training Data
The training dataset is a mix of ASR datasets, extended by adding a "continuation" generated by Llama 3.1 8B.
### Training Procedure
Supervised speech-to-audio finetuning. For more info, see [training code in Ultravox repo](https://github.com/fixie-ai/ultravox/blob/main/ultravox/training/train.py).
#### Training Hyperparameters
- **Training regime:** BF16 mixed precision training
- **Hardware used:** 8x H100 GPUs
#### Speeds, Sizes, Times
The current version of Ultravox, when invoked with audio content, has a time-to-first-token (TTFT) of approximately 200 ms and generates roughly 50-100 tokens per second on an A100-40GB GPU, using the Llama 3.1 8B backbone.
Check out the audio tab on [TheFastest.ai](https://thefastest.ai/?m=audio) for daily benchmarks and a comparison with other existing models.
## Evaluation
| | en_de (BLEU) | es_en (BLEU) | LibriSpeech clean.test (WER) |
|:------------------|:-------------|:-------------|:----------------------------|
| Ultravox v0.2 | 12.07 | 15.17 | 6.07 |
| **Ultravox v0.3** | 22.68 | 24.10 | 6.67 |
| Whisper-Llama3.1 | 24.89 | 28.67 | 3.4 |
| Llama3.1 (text-only) | 31.95 | 38.28 | - |

204
config.json Normal file

@@ -0,0 +1,204 @@
{
"_name_or_path": "/home/ubuntu/Disk/ultravox/artifacts/model-zhuang.2024-07-31-ultravox.blsp-kd-2a-v5",
"architectures": [
"UltravoxModel"
],
"audio_config": {
"_name_or_path": "openai/whisper-small",
"activation_dropout": 0.0,
"activation_function": "gelu",
"apply_spec_augment": false,
"architectures": [
"WhisperForConditionalGeneration"
],
"attention_dropout": 0.0,
"begin_suppress_tokens": [
220,
50257
],
"bos_token_id": 50257,
"d_model": 768,
"decoder_attention_heads": 12,
"decoder_ffn_dim": 3072,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 50258,
"dropout": 0.0,
"encoder_attention_heads": 12,
"encoder_ffn_dim": 3072,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 50257,
"forced_decoder_ids": [
[
1,
50259
],
[
2,
50359
],
[
3,
50363
]
],
"init_std": 0.02,
"is_encoder_decoder": true,
"max_length": 448,
"max_source_positions": 1500,
"max_target_positions": 448,
"median_filter_width": 7,
"model_type": "whisper",
"num_hidden_layers": 12,
"num_mel_bins": 80,
"pad_token_id": 50257,
"scale_embedding": false,
"suppress_tokens": [
1,
2,
7,
8,
9,
10,
14,
25,
26,
27,
28,
29,
31,
58,
59,
60,
61,
62,
63,
90,
91,
92,
93,
359,
503,
522,
542,
873,
893,
902,
918,
922,
931,
1350,
1853,
1982,
2460,
2627,
3246,
3253,
3268,
3536,
3846,
3961,
4183,
4667,
6585,
6647,
7273,
9061,
9383,
10428,
10929,
11938,
12033,
12331,
12562,
13793,
14157,
14635,
15265,
15618,
16553,
16604,
18362,
18956,
20075,
21675,
22520,
26130,
26161,
26435,
28279,
29464,
31650,
32302,
32470,
36865,
42863,
47425,
49870,
50254,
50258,
50360,
50361,
50362
],
"torch_dtype": "float32",
"use_cache": true,
"vocab_size": 51865
},
"audio_model_id": "openai/whisper-small",
"audio_token_index": 32000,
"auto_map": {
"AutoConfig": "ultravox_config.UltravoxConfig",
"AutoModel": "ultravox_model.UltravoxModel",
"AutoProcessor": "ultravox_processing.UltravoxProcessor"
},
"custom_pipelines": {
"ultravox-pipeline": {
"impl": "ultravox_pipeline.UltravoxPipeline",
"pt": [
"AutoModel"
],
"tf": [],
"type": "multimodal"
}
},
"hidden_size": 4096,
"ignore_index": -100,
"initializer_range": 0.02,
"model_type": "ultravox",
"norm_init": 0.4,
"projector_act": "swiglu",
"stack_factor": 8,
"text_config": {
"_name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 128000,
"eos_token_id": [
128001,
128008,
128009
],
"intermediate_size": 14336,
"max_position_embeddings": 131072,
"model_type": "llama",
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 8.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"torch_dtype": "bfloat16",
"vocab_size": 128256
},
"text_model_id": null,
"torch_dtype": "bfloat16",
"transformers_version": "4.43.2",
"vocab_size": 128256
}

1
configuration.json Normal file

@@ -0,0 +1 @@
{"framework": "pytorch", "task": "feature-extraction", "allow_remote": true}

11
generation_config.json Normal file

@@ -0,0 +1,11 @@
{
"_from_model_config": true,
"bos_token_id": 128000,
"eos_token_id": [
128001,
128008,
128009
],
"pad_token_id": 128009,
"transformers_version": "4.43.2"
}

BIN
model-00001-of-00004.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

BIN
model-00002-of-00004.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

BIN
model-00003-of-00004.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

BIN
model-00004-of-00004.safetensors (Stored with Git LFS) Normal file

Binary file not shown.

302
model.safetensors.index.json Normal file

@@ -0,0 +1,302 @@
{
"metadata": {
"total_size": 16127651840
},
"weight_map": {
"language_model.lm_head.weight": "model-00004-of-00004.safetensors",
"language_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.19.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.20.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.20.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.20.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
"language_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
"language_model.model.layers.31.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
"language_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
"language_model.model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.31.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.31.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.31.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
"language_model.model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.8.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.8.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
"language_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
"language_model.model.norm.weight": "model-00004-of-00004.safetensors",
"multi_modal_projector.linear_1.weight": "model-00001-of-00004.safetensors",
"multi_modal_projector.linear_2.weight": "model-00001-of-00004.safetensors",
"multi_modal_projector.ln_post.weight": "model-00001-of-00004.safetensors",
"multi_modal_projector.ln_pre.weight": "model-00001-of-00004.safetensors"
}
}

17
special_tokens_map.json Normal file

@@ -0,0 +1,17 @@
{
"bos_token": {
"content": "<|begin_of_text|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|eot_id|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": "<|eot_id|>"
}

410563
tokenizer.json Normal file

File diff suppressed because it is too large

2063
tokenizer_config.json Normal file

File diff suppressed because it is too large

157
ultravox_config.py Normal file

@@ -0,0 +1,157 @@
import dataclasses
from enum import Enum
from typing import Any, Dict, List, Optional
import transformers
@dataclasses.dataclass
class LoraConfigSimplified:
"""
Low Rank Approximation (LoRA) configuration.
Used for language and audio models separately.
"""
# The rank of the approximation
r: int = 0
lora_alpha: float = 8
target_modules: Optional[List[str]] = dataclasses.field(
default_factory=lambda: ["k_proj", "q_proj", "linear_k", "linear_q"]
)
class LossFunction(str, Enum):
CrossEntropy = "ce"
KL_Divergence = "kl"
@dataclasses.dataclass
class LossConfig:
loss_function: LossFunction = LossFunction.KL_Divergence
kl_temperature: float = 2.0
@property
def requires_alt_fields(self):
return self.loss_function == LossFunction.KL_Divergence
class UltravoxConfig(transformers.PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`UltravoxForConditionalGeneration`]. It is used to instantiate an
Ultravox model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
audio_config (`Wav2Vec2Config`, *optional*):
Custom audio config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
audio_token_index (`int`, *optional*, defaults to 32000):
The audio token index to encode the audio prompt.
stack_factor (`int`, *optional*, defaults to 8):
Audio downsampling factor for the multimodal projector.
norm_init (`float`, *optional*, defaults to 0.4):
The initialization value for the layer normalization.
projector_act (`str`, *optional*, defaults to `"swiglu"`):
The activation function used by the multimodal projector.
text_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the text model.
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the audio model.
Example:
```python
>>> from transformers import UltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig
>>> # Initializing an audio encoder config
>>> audio_config = Wav2Vec2Config()
>>> # Initializing a Llama config
>>> text_config = LlamaConfig()
>>> # Initializing a default configuration
>>> configuration = UltravoxConfig(audio_config, text_config)
>>> # Initializing a completely untrained model from the configuration
>>> model = UltravoxForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # Initialize a model from pretrained checkpoints and random projector weights
>>> config = UltravoxConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
```"""
model_type = "ultravox"
is_composition = False
def __init__(
self,
audio_config: Optional[Dict[str, Any]] = None,
text_config: Optional[Dict[str, Any]] = None,
audio_model_id: Optional[str] = None,
text_model_id: Optional[str] = None,
ignore_index: int = -100,
audio_token_index: int = 32000,
hidden_size: int = 4096,
stack_factor: int = 8,
norm_init: float = 0.4,
projector_act: str = "swiglu",
text_model_lora_config: Optional[LoraConfigSimplified] = None,
audio_model_lora_config: Optional[LoraConfigSimplified] = None,
**kwargs,
):
self.ignore_index = ignore_index
self.audio_model_id = audio_model_id
self.text_model_id = text_model_id
self.audio_token_index = audio_token_index
self.hidden_size = hidden_size
self.stack_factor = stack_factor
self.norm_init = norm_init
self.projector_act = projector_act
if text_model_id is not None:
self.text_config: transformers.LlamaConfig = (
transformers.AutoConfig.from_pretrained(text_model_id)
)
else:
text_config = text_config or {}
self.text_config = transformers.CONFIG_MAPPING[
text_config.get("model_type", "llama")
](**text_config)
if audio_model_id is not None:
self.audio_config: transformers.PretrainedConfig = (
transformers.AutoConfig.from_pretrained(audio_model_id)
)
else:
audio_config = audio_config or {}
self.audio_config = transformers.CONFIG_MAPPING[
audio_config.get("model_type", "wav2vec2")
](**audio_config)
self.text_model_lora_config = (
text_model_lora_config
if isinstance(text_model_lora_config, dict)
else dataclasses.asdict(text_model_lora_config or LoraConfigSimplified())
)
self.audio_model_lora_config = (
audio_model_lora_config
if isinstance(audio_model_lora_config, dict)
else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
)
self.vocab_size = self.text_config.vocab_size
self.initializer_range = self.text_config.initializer_range
super().__init__(**kwargs)

504
ultravox_model.py Normal file

@@ -0,0 +1,504 @@
import logging
from typing import Any, Dict, Optional, Set, Tuple, Union
import peft
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import transformers.activations
import transformers.modeling_outputs
import transformers.models
# We must use relative import in this directory to allow uploading to HF Hub
# Even "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_config import LossConfig
from .ultravox_config import LossFunction
from .ultravox_config import UltravoxConfig
from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
class UltravoxModel(transformers.LlamaPreTrainedModel):
"""
The Ultravox model which consists of an audio encoder and a language model.
Audio input is processed by the audio encoder, then every `stack_factor` frames are stacked together and
projected to the language model's embedding space using a few linear layers.
The text is embedded by the language model as usual and then the audio and text embeddings are merged together.
A special token `<|audio|>` is used to indicate the start of the audio embeddings in the merged embeddings.
Parameters:
config: Model configuration class with all the parameters of the model.
"""
config_class = UltravoxConfig
config: UltravoxConfig # for type hinting
_no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
# We minimize the weights in state_dict in order to reduce the size of the checkpoint
# The issue is that load_pretrained() uses state_dict() keys to know what keys are expected
# As such we have to tell it to ignore some keys that are not always in the model
_keys_to_ignore_on_load_unexpected = ["audio_tower.*", "language_model.*"]
# Usually we load encoder weights from a pretrained model, so we don't want to load the decoder weights
# Technically we never hit this issue because these keys are already removed from state_dict(); however,
# there's no harm in keeping it here for when we change that behavior.
_keys_to_ignore_on_load_missing = ["audio_tower.*"]
def __init__(self, config: UltravoxConfig):
super().__init__(config)
self.keep_params: Set[str] = set()
self.vocab_size = config.vocab_size
self.audio_tower = self._create_audio_tower(config)
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = self._create_language_model(config)
self.loss_config = LossConfig()
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def set_loss_config(self, loss_config: LossConfig):
self.loss_config = loss_config
def _setup_cache(
self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None
):
self.language_model._setup_cache(cache_cls, max_batch_size, max_cache_len)
def _reorder_cache(self, past_key_values, beam_idx):
return self.language_model._reorder_cache(past_key_values, beam_idx)
def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(
new_num_tokens, pad_to_multiple_of
)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def _compute_kl_loss(
self,
lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
alt_input_ids: Optional[torch.Tensor] = None,
alt_attention_mask: Optional[torch.Tensor] = None,
alt_labels: Optional[torch.Tensor] = None,
**kwargs,
):
# disable gradient computation for the teacher model
with torch.no_grad():
# compute the teacher (text-only) model's distribution
alt_inputs_embeds = self.get_input_embeddings().forward(alt_input_ids)
alt_lm_output = self.language_model.forward(
inputs_embeds=alt_inputs_embeds,
labels=alt_labels,
attention_mask=alt_attention_mask,
past_key_values=past_key_values,
**kwargs,
)
# compute the KL divergence loss between the two models
kl_loss = F.kl_div(
F.log_softmax(
lm_output.logits[labels != -100] / self.loss_config.kl_temperature,
dim=-1,
),
F.softmax(
alt_lm_output.logits[alt_labels != -100]
/ self.loss_config.kl_temperature,
dim=-1,
),
reduction="batchmean",
)
return {"loss": kl_loss}
def forward(
self,
input_ids: torch.Tensor,
audio_values: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
# the alt_* fields are needed for KL divergence loss
alt_input_ids: Optional[torch.Tensor] = None,
alt_attention_mask: Optional[torch.Tensor] = None,
alt_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, transformers.modeling_outputs.CausalLMOutputWithPast]:
"""
Forward pass for the Ultravox model.
`input_ids` are the tokenized text input. They are embedded by the language model as usual.
`audio_values` are processed by the audio encoder and then every `stack_factor` frames are stacked together and
projected to the language model's embedding space using a few linear layers.
The audio and text embeddings are merged together. A special token `<|audio|>` is used to indicate the start
of the audio embeddings in the merged embeddings.
Args:
input_ids: The tokenized text input.
audio_values: The processed audio values.
inputs_embeds: The embeddings for the input tokens.
labels: The tokenized text labels.
attention_mask: The attention mask for the input.
position_ids: The position ids for the input.
past_key_values: The past key value cache for the language model attention layers.
**kwargs: Additional keyword arguments. Passed directly to the language model.
"""
if inputs_embeds is None:
# B x T -> B x T x D
inputs_embeds = self.get_input_embeddings().forward(input_ids)
if audio_values is not None:
assert (
audio_token_start_idx is not None and audio_token_len is not None
), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
assert (
len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
audio_values
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
# combine audio and text embeddings
for i, (audio, start, length) in enumerate(
zip(audio_embeds, audio_token_start_idx, audio_token_len)
):
length = min(length, audio.shape[0])
inputs_embeds[i, start : start + length] = audio[:length]
lm_output = self.language_model.forward(
inputs_embeds=inputs_embeds,
labels=labels,
attention_mask=attention_mask,
past_key_values=past_key_values,
**kwargs,
)
if self.training:
if self.loss_config.loss_function == LossFunction.CrossEntropy:
return lm_output
elif self.loss_config.loss_function == LossFunction.KL_Divergence:
return self._compute_kl_loss(
lm_output=lm_output,
labels=labels,
past_key_values=past_key_values,
alt_input_ids=alt_input_ids,
alt_attention_mask=alt_attention_mask,
alt_labels=alt_labels,
**kwargs,
)
else:
raise ValueError(
f"Unsupported loss function: {self.loss_config.loss_function}"
)
else:
return lm_output
def prepare_inputs_for_generation(
self,
input_ids: torch.Tensor,
audio_values: Optional[torch.FloatTensor] = None,
audio_token_start_idx: Optional[torch.Tensor] = None,
audio_token_len: Optional[torch.Tensor] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Dict[str, Any]:
model_input = self.language_model.prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs,
)
if is_cache_empty(past_key_values) and audio_values is not None:
# We only want to use audio features in the 1st generation step
model_input["audio_values"] = audio_values
model_input["audio_token_start_idx"] = audio_token_start_idx
model_input["audio_token_len"] = audio_token_len
return model_input
@classmethod
def _create_audio_tower(
cls, config: UltravoxConfig
) -> Union[transformers.Wav2Vec2Model, ModifiedWhisperEncoder]:
if config.audio_model_id is not None:
if "whisper" in config.audio_model_id is not None:
audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id
)
else:
audio_tower = transformers.AutoModel.from_pretrained(
config.audio_model_id
)
else:
if "whisper" in config.audio_config._name_or_path:
audio_tower = ModifiedWhisperEncoder(config.audio_config)
else:
with transformers.modeling_utils.no_init_weights():
# we only ever use from_config if the weights are retrained, hence initializing is not
# required. This makes model creation much faster since init on CPU is quite slow.
audio_tower = transformers.AutoModel.from_config(
config.audio_config
)
if isinstance(
audio_tower,
(transformers.Wav2Vec2BertModel, transformers.WhisperModel),
):
# For these models we only need the encoder part
# Wav2Vec2BertModel -> Wav2Vec2BertEncoder
# WhisperModel -> WhisperEncoder
audio_tower = audio_tower.encoder
audio_tower = apply_lora(audio_tower, config.audio_model_lora_config)
return audio_tower
@classmethod
def _create_language_model(
cls, config: UltravoxConfig
) -> transformers.LlamaForCausalLM:
if config.text_model_id is not None:
language_model = transformers.AutoModelForCausalLM.from_pretrained(
config.text_model_id, attn_implementation=config._attn_implementation
)
else:
with transformers.modeling_utils.no_init_weights():
# we only ever use from_config if the weights are retrained, hence initializing is not
# required. This makes model creation much faster since init on CPU is quite slow.
language_model = transformers.AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = apply_lora(language_model, config.text_model_lora_config)
return language_model
def _add_language_model_weights_to_keep(self):
if self.config.text_model_id is not None:
self.config.text_model_id = None
self.keep_params.update(
set(
[
f"language_model.{name}"
for name, _ in self.language_model.named_parameters()
]
)
)
def _add_audio_tower_weights_to_keep(self):
if self.config.audio_model_id is not None:
self.config.audio_model_id = None
self.keep_params.update(
set(
[
f"audio_tower.{name}"
for name, _ in self.audio_tower.named_parameters()
]
)
)
def merge_and_unload(self):
if isinstance(self.language_model, peft.PeftModel):
self.language_model = self.language_model.merge_and_unload()
# no need to download base language model weights anymore, so we can remove the id
self._add_language_model_weights_to_keep()
if isinstance(self.audio_tower, peft.PeftModel):
self.audio_tower = self.audio_tower.merge_and_unload()
# no need to download base audio model weights anymore, so we can remove the id
self._add_audio_tower_weights_to_keep()
for param in ["text_model_lora_config", "audio_model_lora_config"]:
if hasattr(self.config, param):
delattr(self.config, param)
def push_to_hub(self, *args, **kwargs):
self.merge_and_unload()
self.to(self.language_model.dtype)
return super().push_to_hub(*args, **kwargs)
def state_dict(self, *args, **kwargs):
named_params = dict(self.named_parameters())
state_dict = super().state_dict(*args, **kwargs)
state_dict = {
k: v
for k, v in state_dict.items()
if k in self.keep_params
or (k in named_params and named_params[k].requires_grad)
}
return state_dict
def load_state_dict(
self,
state_dict: Dict[str, Any],
*args,
**kwargs,
):
self.keep_params.update(set(state_dict.keys()))
return super().load_state_dict(state_dict, *args, **kwargs)
def print_trainable_parameters(self):
"""
Prints the number of trainable parameters in the model (reuses Peft model's method)
"""
count_params = peft.peft_model.PeftModel.get_nb_trainable_parameters
trainable_params, all_param = count_params(self)
logging.info(
f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
f" || trainable%: {100 * trainable_params / all_param:.1f}%"
)
lm_trainable_params, lm_all_params = count_params(self.language_model)
audio_trainable_params, audio_all_params = count_params(self.audio_tower)
projector_trainable_params = (
trainable_params - lm_trainable_params - audio_trainable_params
)
projector_all_params = all_param - lm_all_params - audio_all_params
logging.info(
f"Trainable%: "
f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%"
f" || Audio Encoder: {100 * audio_trainable_params / audio_all_params:.1f}%"
f" || Projector: {100 * projector_trainable_params / projector_all_params:.1f}%"
)
def is_cache_empty(
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
) -> bool:
"""
Check if the cache is empty.
"""
if past_key_values is None:
return True
if isinstance(past_key_values, tuple):
return all(len(c) == 0 for c in past_key_values)
return past_key_values.get_seq_length() == 0
def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
"""
Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
"""
lora_config = peft.LoraConfig(**lora_config or {})
if lora_config.r == 0:
# freeze the model entirely
for param in model.parameters():
param.requires_grad = False
else:
model = peft.get_peft_model(model, lora_config)
return model
class StackAudioFrames(nn.Module):
"""
Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
NOTE: the extra +1 is intentional: in case the number of audio tokens is over-estimated by the processor,
we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
In most cases this extra padding will get removed in the model's forward function so it has no effect.
"""
def __init__(self, stack_factor: int = 8):
super().__init__()
self.stack_factor = stack_factor
def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
B, T, C = audio_embeds.shape
T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
B, T, C = audio_embeds.shape
audio_embeds = audio_embeds.view(
B, T // self.stack_factor, C * self.stack_factor
)
return audio_embeds
class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm):
def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
super().__init__(hidden_size=hidden_size, eps=eps)
self.weight.data.fill_(init)
class SwiGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
class UltravoxProjector(nn.Sequential):
def __init__(self, config: UltravoxConfig):
super().__init__()
self.hidden_dim = config.hidden_size
self._pad_and_stack = StackAudioFrames(config.stack_factor)
dim = config.audio_config.hidden_size * config.stack_factor
self.ln_pre = RMSNorm(dim, init=config.norm_init)
self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
dim = self.hidden_dim
self.act = transformers.activations.get_activation(config.projector_act)
dim = dim // 2 if config.projector_act == "swiglu" else dim
self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
audio_features = self._pad_and_stack(audio_features)
audio_features = self.ln_pre(audio_features)
hidden_states = self.linear_1(audio_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.ln_post(hidden_states)
return hidden_states
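A minimal sketch of the projector in isolation, using a hand-built stand-in config instead of a real `UltravoxConfig`; the field values below are assumptions that roughly match a Whisper-small encoder and a Llama-style text model:

```python
from types import SimpleNamespace

import torch

cfg = SimpleNamespace(
    hidden_size=1024,                                # projector hidden width (assumed)
    stack_factor=8,
    norm_init=0.4,                                   # assumed RMSNorm init
    projector_act="swiglu",                          # relies on the ACT2FN registration at the bottom of this file
    audio_config=SimpleNamespace(hidden_size=768),   # Whisper-small encoder width
    text_config=SimpleNamespace(hidden_size=4096),   # Llama-style hidden size
)
projector = UltravoxProjector(cfg)
out = projector(torch.randn(1, 100, cfg.audio_config.hidden_size))
print(out.shape)   # torch.Size([1, 14, 4096])
```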
UltravoxConfig.register_for_auto_class()
UltravoxModel.register_for_auto_class()
transformers.AutoConfig.register("ultravox", UltravoxConfig)
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
# transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor) # TODO: make processor work standalone
transformers.activations.ACT2FN["swiglu"] = SwiGLU

ultravox_pipeline.py (new file, 127 lines)
import logging
from typing import Any, Dict, List, Optional
import numpy as np
import transformers
# We must use relative import in this directory to allow uploading to HF Hub
# Even the "from . import X" pattern doesn't work (undocumented and unclear why)
from .ultravox_model import UltravoxModel
from .ultravox_processing import UltravoxProcessor
class UltravoxPipeline(transformers.Pipeline):
def __init__(
self,
model: UltravoxModel,
tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
audio_processor: Optional[transformers.ProcessorMixin] = None,
**kwargs
):
if tokenizer is None:
try:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model.config._name_or_path
)
except Exception:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model.config.text_model_id or model.config.text_config._name_or_path
)
if audio_processor is None:
audio_processor = transformers.AutoProcessor.from_pretrained(
model.config.audio_model_id or model.config.audio_config._name_or_path
)
super().__init__(model=model, tokenizer=tokenizer, **kwargs)
self.processor = UltravoxProcessor(
audio_processor=audio_processor,
tokenizer=tokenizer,
stack_factor=model.config.stack_factor,
)
def _sanitize_parameters(self, **kwargs):
generation_keys = ["temperature", "max_new_tokens", "repetition_penalty"]
generation_kwargs = {k: kwargs[k] for k in kwargs if k in generation_keys}
return {}, generation_kwargs, {}
def preprocess(self, inputs: Dict[str, Any]):
turns: list = inputs.get("turns", [])
audio = inputs.get("audio", None)
# Convert to float32 if needed.
if isinstance(audio, np.ndarray):
if audio.dtype == np.float64:
audio = audio.astype(np.float32)
elif audio.dtype == np.int16:
audio = audio.astype(np.float32) / np.float32(32768.0)
elif audio.dtype == np.int32:
audio = audio.astype(np.float32) / np.float32(2147483648.0)
if audio is not None and (len(turns) == 0 or turns[-1]["role"] != "user"):
prompt = inputs.get("prompt", "<|audio|>")
if "<|audio|>" not in prompt:
logging.warning(
"Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
)
prompt += " <|audio|>"
turns.append({"role": "user", "content": prompt})
text = self.processor.tokenizer.apply_chat_template(
turns, add_generation_prompt=True, tokenize=False
)
if "sampling_rate" not in inputs and audio is not None:
logging.warning(
"No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
)
output = self.processor(
text=text,
audio=audio,
sampling_rate=inputs.get("sampling_rate", 16000),
)
if "audio_values" in output:
output["audio_values"] = output["audio_values"].to(self.model.dtype)
return output
def _forward(
self,
model_inputs: Dict[str, Any],
temperature: Optional[float] = None,
max_new_tokens: Optional[int] = None,
repetition_penalty: float = 1.1,
) -> List[int]:
temperature = temperature or None
do_sample = temperature is not None
terminators = [self.tokenizer.eos_token_id]
if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
terminators.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))
input_len = model_inputs["input_ids"].shape[1]
outputs = self.model.generate(
**model_inputs,
do_sample=do_sample,
temperature=temperature,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
eos_token_id=terminators
)
return outputs[0][input_len:]
def postprocess(self, model_outputs) -> str:
output_text = self.tokenizer.decode(model_outputs, skip_special_tokens=True)
return output_text
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
"ultravox-pipeline",
pipeline_class=UltravoxPipeline,
pt_model=transformers.AutoModel,
type="multimodal",
)
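As a hedged sketch (the repo id, device, and prompt are assumptions), the pipeline can also be constructed directly; note that `preprocess` converts int16/int32 PCM to float32 and leaves a final user turn that already contains `<|audio|>` untouched:

```python
import numpy as np

model = UltravoxModel.from_pretrained("fixie-ai/ultravox-v0_3")
pipe = UltravoxPipeline(model=model, device="cpu")

pcm = np.zeros(16000, dtype=np.int16)   # 1 second of silence as int16 PCM
turns = [{"role": "user", "content": "Describe this clip: <|audio|>"}]
text = pipe({"audio": pcm, "turns": turns, "sampling_rate": 16000}, max_new_tokens=20)
print(text)
```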

ultravox_processing.py (new file, 205 lines)
from typing import Optional, Union, Dict, Any
import numpy as np
import torch
import transformers
from .ultravox_config import UltravoxConfig
class UltravoxProcessor(transformers.ProcessorMixin):
"""
Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
Args:
audio_processor: The audio processor for the audio encoder.
tokenizer: The tokenizer for the language model.
"""
attributes = ["audio_processor", "tokenizer"]
audio_processor_class = (
"Wav2Vec2Processor",
"SeamlessM4TFeatureExtractor",
"WhisperProcessor",
)
tokenizer_class = (
"PreTrainedTokenizer",
"PreTrainedTokenizerFast",
)
tokenizer: transformers.PreTrainedTokenizerBase
audio_processor: transformers.ProcessorMixin
def __init__(
self,
audio_processor=None,
tokenizer=None,
audio_padding: str = "longest",
encoder_ds_factor: int = 320,
stack_factor: int = 8,
audio_placeholder: str = "<|audio|>",
):
"""
Args:
audio_processor: The audio processor for the audio encoder.
tokenizer: The tokenizer for the language model.
audio_padding: The padding strategy for the audio encoder.
encoder_ds_factor: The downsample factor of the audio encoder.
stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
audio_placeholder: The placeholder for the audio in the text.
"""
self.audio_padding = audio_padding
self.encoder_ds_factor = encoder_ds_factor
self.stack_factor = stack_factor
self.audio_placeholder = audio_placeholder
self.audio_token_replacement = tokenizer.eos_token
assert (
self.audio_token_replacement is not None
), "The tokenizer has no EOS token. Cannot recover."
super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
audio_processor = transformers.AutoProcessor.from_pretrained(
config.audio_model_id
or config.audio_config._name_or_path
or "facebook/wav2vec2-base-960h"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, **kwargs
)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
return cls(
audio_processor=audio_processor,
tokenizer=tokenizer,
stack_factor=config.stack_factor,
)
def __call__(
self,
text: Optional[str] = None,
audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
sampling_rate: Optional[int] = None,
return_tensors: Optional[
Union[str, transformers.TensorType]
] = transformers.TensorType.PYTORCH,
**kwargs,
) -> transformers.BatchFeature:
"""
Main method to prepare one text sequence and one audio clip for the model. This method forwards the `text`
and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to
encode the text. To prepare the audio, it forwards the `audio`, `sampling_rate` and `kwargs` arguments to the
audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstrings of
the above two methods for more information.
Args:
text (`str`):
The sequence to be encoded. Only a single string is supported; batch mode is not supported yet.
audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The audio to be prepared. Audio can be a NumPy array or a PyTorch tensor. In either case, each audio clip
should be of shape (C, T), where C is the number of channels and T is the sample length of the audio.
sampling_rate (`int`, *optional*, defaults to 16000):
Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
you are doing.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
- **audio_token_len** -- Predicted number of audio frames: this value is guaranteed to be a close upper bound.
Returned when `audio` is not `None`.
- **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
"""
# TODO: Add support for multiple audio and text inputs.
data = {}
audio_embed_frames = 0
if audio is not None and len(audio) > 0:
if self.audio_padding == "max_length":
# 30 seconds is the expected length for Whisper
assert sampling_rate is not None, "Sampling rate must be provided."
audio_len = 30 * sampling_rate
else:
audio_len = audio.shape[-1]
# It's guaranteed that the number of frames is less than or equal to this amount.
# For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
# Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
data["audio_token_len"] = [audio_embed_frames]
# Main audio processing. The processor is model-specific.
x = self.audio_processor(
audio,
sampling_rate=sampling_rate,
padding="longest",
max_length=audio_len,
**kwargs,
)
if "input_features" in x:
data["audio_values"] = x.input_features
else:
data["audio_values"] = x.input_values
if text is not None:
assert isinstance(
text, str
), "Text must be a string. Batch mode not supported yet."
if self.audio_placeholder in text:
if "audio_token_len" not in data:
raise ValueError(
f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
)
start_idx = len(
self.tokenizer.encode(
text[: text.index(self.audio_placeholder)],
add_special_tokens=False,
)
)
data["audio_token_start_idx"] = [start_idx]
# Replace the audio placeholder with the audio token.
# e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
# where the number of </s> is the number of audio frames.
text = text.replace(
self.audio_placeholder,
self.audio_token_replacement * audio_embed_frames,
)
# Special tokens like BOS should already have been added by the caller.
data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
return transformers.BatchFeature(data=data, tensor_type=return_tensors)
def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
audio_processor_input_names = self.audio_processor.model_input_names
return list(set(tokenizer_input_names + audio_processor_input_names))
transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
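To make the token bookkeeping concrete, a worked sketch under assumed component ids: 5 s of 16 kHz audio is 80,000 samples, so with the defaults `encoder_ds_factor=320` and `stack_factor=8` the processor predicts `round(80000 / 320) = 250` encoder frames and `ceil(250 / 8) = 32` audio tokens, and the `<|audio|>` placeholder is expanded to 32 EOS tokens in the text:

```python
import numpy as np
import transformers

# Component ids are assumptions; any Whisper-style processor and an EOS-bearing tokenizer will do.
audio_processor = transformers.AutoProcessor.from_pretrained("openai/whisper-small")
tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
processor = UltravoxProcessor(audio_processor=audio_processor, tokenizer=tokenizer, stack_factor=8)

audio = np.zeros(5 * 16000, dtype=np.float32)   # 5 s of silence at 16 kHz
batch = processor(text="Transcribe: <|audio|>", audio=audio, sampling_rate=16000)

print(batch["audio_token_len"])        # tensor([32])
print(batch["audio_token_start_idx"])  # position of the first audio token in input_ids
print(batch["audio_values"].shape)     # feature shape is specific to the audio processor
```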

whisper_model_modified.py (new file, 141 lines)
# modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
# see this issue for the commentary: https://github.com/huggingface/transformers/issues/25744
#
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import transformers
import transformers.modeling_outputs
from transformers.models.whisper import modeling_whisper as whisper
class WhisperEncoder(whisper.WhisperEncoder):
"""
Encoder portion of OpenAI's Whisper model.
This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
2. allow audio padded to less than 30 seconds to be passed in:
- relaxed the ValueError check so the `input_features` length may be less than or equal to `expected_seq_length` instead of strictly equal
- embed_pos is now sliced to match the length of `inputs_embeds`
Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
"""
base_model_prefix = "model.encoder"
def forward(
self,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
expected_seq_length = (
self.config.max_source_positions
* self.conv1.stride[0]
* self.conv2.stride[0]
)
if input_features.shape[-1] > expected_seq_length:
raise ValueError(
f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
)
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop: # skip the layer
to_drop = True
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
None,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
None,
layer_head_mask=(
head_mask[idx] if head_mask is not None else None
),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, encoder_states, all_attentions]
if v is not None
)
return transformers.modeling_outputs.BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
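A short sketch of the relaxation this subclass provides (checkpoint id assumed): mel inputs shorter than Whisper's usual 3000 padded frames are accepted, and the positional embeddings are sliced to match:

```python
import torch

encoder = WhisperEncoder.from_pretrained("openai/whisper-small")   # enabled by the base_model_prefix override
mels = torch.randn(1, 80, 500)   # ~5 s of 80-bin log-mel frames instead of the full 3000
with torch.no_grad():
    out = encoder(mels)
print(out.last_hidden_state.shape)   # torch.Size([1, 250, 768]): conv2's stride of 2 halves the frame count
```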