first commit

This commit is contained in:
xxl 2025-01-10 12:57:39 +08:00
parent 385432bacf
commit 175f9431f2
9 changed files with 833 additions and 2 deletions

51
LICENSE.md Normal file
View File

@ -0,0 +1,51 @@
# LICENSE
## 1. Model & License Summary
This repository contains **TangoFlux** (the “Model”) created for **non-commercial, research-only** purposes under the **UK data copyright exemption**. The Model is subject to:
1. The **Stability AI Community License Agreement**, provided in the file ```STABILITY_AI_COMMUNITY_LICENSE.md```.
2. The **WavCaps** license requirement: **only academic uses** are permitted for data sourced from WavCaps.
3. The **original licenses** of the datasets used in training.
By using or distributing this Model, you **agree** to adhere to all applicable licenses and restrictions, as summarized below.
---
## 2. Stability AI Community License Requirements
- You must comply with the **Stability AI Community License Agreement** (the “Agreement”) for any usage, distribution, or modification of this Model.
- **Non-Commercial Use**: This Model is for research and academic purposes only. Any commercial usage requires registering with Stability AI or obtaining a separate commercial license.
- **Attribution & Notice**:
- Retain the notice:
```
This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved.
```
- Clearly display “Powered by Stability AI” if you build upon or showcase this Model.
- **Disclaimer & Liability**: This Model is provided **“AS IS”** with **no warranties**. Neither we nor Stability AI will be liable for any claim or damages related to Model use.
See ```STABILITY_AI_COMMUNITY_LICENSE.md``` for the full text.
---
## 3. WavCaps & Dataset Usage
- **Academic-Only for WavCaps**: By accessing any WavCaps-sourced data (including audio clips via provided links), you agree to use them **strictly for non-commercial, academic research** in accordance with WavCaps terms.
- **WavCaps Audio**: Each WavCaps audio subset has its own license terms. **You** are responsible for reviewing and complying with those licenses, including attribution requirements on your end.
---
## 4. UK Data Copyright Exemption
This Model was developed under the **UK data copyright exemption for non-commercial research**. Distribution or use outside these bounds must **not** violate that exemption or infringe on any underlying datasets license.
---
## 5. Further Information
- **Stability AI License Terms**: <https://stability.ai/community-license>
- **WavCaps License**: <https://github.com/XinhaoMei/WavCaps?tab=readme-ov-file#license>
---
**End of License**.

View File

@ -1,3 +1,95 @@
# TangoFlux
---
datasets:
- cvssp/WavCaps
cite: arxiv.org/abs/2412.21037
pipeline_tag: text-to-audio
---
TangoFlux
<h1 align="center">
<br/>
TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization
<br/>
✨✨✨
</h1>
<div align="center">
<img src="https://raw.githubusercontent.com/declare-lab/TangoFlux/refs/heads/main/assets/tf_teaser.png" alt="TangoFlux" width="1000" />
<br/>
<div style="display: flex; gap: 10px; align-items: center;">
<a href="https://openreview.net/attachment?id=tpJPlFTyxd&name=pdf">
<img src="https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf" alt="arXiv">
</a>
<a href="https://huggingface.co/declare-lab/TangoFlux">
<img src="https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
</a>
<a href="https://tangoflux.github.io/">
<img src="https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat" alt="Static Badge">
</a>
<a href="https://huggingface.co/spaces/declare-lab/TangoFlux">
<img src="https://img.shields.io/badge/TangoFlux-Huggingface_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
</a>
<a href="https://huggingface.co/datasets/declare-lab/CRPO">
<img src="https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
</a>
<a href="https://github.com/declare-lab/TangoFlux">
<img src="https://img.shields.io/badge/Github-brown?logo=github&link=https%3A%2F%2Fgithub.com%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
</a>
</div>
</div>
* Powered by **Stability AI**
## Model Overview
TangoFlux consists of FluxTransformer blocks which are Diffusion Transformer (DiT) and Multimodal Diffusion Transformer (MMDiT), conditioned on textual prompt and duration embedding to generate audio at 44.1kHz up to 30 seconds. TangoFlux learns a rectified flow trajectory from audio latent representation encoded by a variational autoencoder (VAE). The TangoFlux training pipeline consists of three stages: pre-training, fine-tuning, and preference optimization. TangoFlux is aligned via CRPO which iteratively generates new synthetic data and constructs preference pairs to perform preference optimization.
## Getting Started
Get TangoFlux from our GitHub repo https://github.com/declare-lab/TangoFlux with
```bash
pip install git+https://github.com/declare-lab/TangoFlux
```
The model will be automatically downloaded and saved in a cache. The subsequent runs will load the model directly from the cache.
The `generate` function uses 25 steps by default to sample from the flow model. We recommend using 50 steps for generating better quality audios. This comes at the cost of increased run-time.
```python
import torchaudio
from tangoflux import TangoFluxInference
from IPython.display import Audio
model = TangoFluxInference(name='declare-lab/TangoFlux')
audio = model.generate('Hammer slowly hitting the wooden table', steps=50, duration=10)
Audio(data=audio, rate=44100)
```
## License
The TangoFlux checkpoints are for non-commercial research use only. They are subject to the [Stable Audio Opens license](https://huggingface.co/stabilityai/stable-audio-open-1.0/blob/main/LICENSE.md), [WavCaps license](https://github.com/XinhaoMei/WavCaps?tab=readme-ov-file#license), and the original licenses accompanying each training dataset.
This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved
## Citation
https://arxiv.org/abs/2412.21037
```bibtex
@misc{hung2024tangofluxsuperfastfaithful,
title={TangoFlux: Super Fast and Faithful Text to Audio Generation with Flow Matching and Clap-Ranked Preference Optimization},
author={Chia-Yu Hung and Navonil Majumder and Zhifeng Kong and Ambuj Mehrish and Rafael Valle and Bryan Catanzaro and Soujanya Poria},
year={2024},
eprint={2412.21037},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2412.21037},
}
```

View File

@ -0,0 +1,58 @@
STABILITY AI COMMUNITY LICENSE AGREEMENT
Last Updated: July 5, 2024
1. INTRODUCTION
This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below.
This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent).
By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entitys behalf.
2. RESEARCH & NON-COMMERCIAL USE LICENSE
Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AIs intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing.
3. COMMERCIAL USE LICENSE
Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AIs intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your businesss or organizations internal operations.
If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you.
4. GENERAL TERMS
Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms.
a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified.
b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AIs AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works).
c. Intellectual Property.
(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein.
(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AIs ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law.
(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement.
(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AIs existing or prospective technology, products or services (collectively, “Feedback”). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback.
d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS.
e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement.
g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement.
5. DEFINITIONS
“Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity.
"Agreement" means this Stability AI Community License Agreement.
“AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time.
"Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Models output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Models output, but do not include the output of any Model.
“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models.
“Model(s)" means, collectively, Stability AIs proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stabilitys Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time.
"Stability AI" or "we" means Stability AI Ltd. and its Affiliates.
"Software" means Stability AIs proprietary software made available under this Agreement now or in the future.
“Stability AI Materials” means, collectively, Stabilitys proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement.
“Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.

1
config.json Normal file
View File

@ -0,0 +1 @@
{"num_layers": 6, "num_single_layers": 18, "in_channels": 64, "attention_head_dim": 128, "joint_attention_dim": 1024, "num_attention_heads": 8, "audio_seq_len": 645, "max_duration": 30, "uncondition": false, "text_encoder_name": "google/flan-t5-large"}

44
handler.py Normal file
View File

@ -0,0 +1,44 @@
from typing import Dict, List, Any
from tangoflux import TangoFluxInference
import torchaudio
#from huggingface_inference_toolkit.logging import logger
class EndpointHandler():
def __init__(self, path=""):
# Preload all the elements you are going to need at inference.
# pseudo:
# self.model= load_model(path)
self.model = TangoFluxInference(name='declare-lab/TangoFlux',device='cuda')
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
data args:
inputs (:obj: `str` | `PIL.Image` | `np.array`)
kwargs
Return:
A :obj:`list` | `dict`: will be serialized and returned
"""
logger.info(f"Received incoming request with {data=}")
if "inputs" in data and isinstance(data["inputs"], str):
prompt = data.pop("inputs")
elif "prompt" in data and isinstance(data["prompt"], str):
prompt = data.pop("prompt")
else:
raise ValueError(
"Provided input body must contain either the key `inputs` or `prompt` with the"
" prompt to use for the audio generation, and it needs to be a non-empty string."
)
parameters = data.pop("parameters", {})
num_inference_steps = parameters.get("num_inference_steps", 50)
duration = parameters.get("duration", 10)
guidance_scale = parameters.get("guidance_scale", 3.5)
return self.model.generate(prompt,steps=num_inference_steps,
duration=duration,
guidance_scale=guidance_scale)

510
model.py Normal file
View File

@ -0,0 +1,510 @@
from transformers import T5EncoderModel,T5TokenizerFast
import torch
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from typing import Optional,Union,List
from datasets import load_dataset, Audio
from math import pi
import inspect
import yaml
class StableAudioPositionalEmbedding(nn.Module):
"""Used for continuous time
Adapted from stable audio open.
"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, times: torch.Tensor) -> torch.Tensor:
times = times[..., None]
freqs = times * self.weights[None] * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((times, fouriered), dim=-1)
return fouriered
class DurationEmbedder(nn.Module):
"""
A simple linear projection model to map numbers to a latent space.
Code is adapted from
https://github.com/Stability-AI/stable-audio-tools
Args:
number_embedding_dim (`int`):
Dimensionality of the number embeddings.
min_value (`int`):
The minimum value of the seconds number conditioning modules.
max_value (`int`):
The maximum value of the seconds number conditioning modules
internal_dim (`int`):
Dimensionality of the intermediate number hidden states.
"""
def __init__(
self,
number_embedding_dim,
min_value,
max_value,
internal_dim: Optional[int] = 256,
):
super().__init__()
self.time_positional_embedding = nn.Sequential(
StableAudioPositionalEmbedding(internal_dim),
nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
)
self.number_embedding_dim = number_embedding_dim
self.min_value = min_value
self.max_value = max_value
self.dtype = torch.float32
def forward(
self,
floats: torch.Tensor,
):
floats = floats.clamp(self.min_value, self.max_value)
normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value)
# Cast floats to same type as embedder
embedder_dtype = next(self.time_positional_embedding.parameters()).dtype
normalized_floats = normalized_floats.to(embedder_dtype)
embedding = self.time_positional_embedding(normalized_floats)
float_embeds = embedding.view(-1, 1, self.number_embedding_dim)
return float_embeds
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class TangoFlux(nn.Module):
def __init__(self,config,initialize_reference_model=False):
super().__init__()
self.num_layers = config.get('num_layers', 6)
self.num_single_layers = config.get('num_single_layers', 18)
self.in_channels = config.get('in_channels', 64)
self.attention_head_dim = config.get('attention_head_dim', 128)
self.joint_attention_dim = config.get('joint_attention_dim', 1024)
self.num_attention_heads = config.get('num_attention_heads', 8)
self.audio_seq_len = config.get('audio_seq_len', 645)
self.max_duration = config.get('max_duration', 30)
self.uncondition = config.get('uncondition', False)
self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")
self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
self.max_text_seq_len = 64
self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
self.text_embedding_dim = self.text_encoder.config.d_model
self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim,self.joint_attention_dim),nn.ReLU())
self.duration_emebdder = DurationEmbedder(self.text_embedding_dim,min_value=0,max_value=self.max_duration)
self.transformer = FluxTransformer2DModel(
in_channels=self.in_channels,
num_layers=self.num_layers,
num_single_layers=self.num_single_layers,
attention_head_dim=self.attention_head_dim,
num_attention_heads=self.num_attention_heads,
joint_attention_dim=self.joint_attention_dim,
pooled_projection_dim=self.text_embedding_dim,
guidance_embeds=False)
self.beta_dpo = 2000 ## this is used for dpo training
def get_sigmas(self,timesteps, n_dim=3, dtype=torch.float32):
device = self.text_encoder.device
sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
device = self.text_encoder.device
batch = self.tokenizer(
prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
)
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
with torch.no_grad():
prompt_embeds = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask
)[0]
prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
# get unconditional embeddings for classifier free guidance
uncond_tokens = [""]
max_length = prompt_embeds.shape[1]
uncond_batch = self.tokenizer(
uncond_tokens, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt",
)
uncond_input_ids = uncond_batch.input_ids.to(device)
uncond_attention_mask = uncond_batch.attention_mask.to(device)
with torch.no_grad():
negative_prompt_embeds = self.text_encoder(
input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
)[0]
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
# For classifier free guidance, we need to do two forward passes.
# We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
boolean_prompt_mask = (prompt_mask == 1).to(device)
return prompt_embeds, boolean_prompt_mask
@torch.no_grad()
def encode_text(self, prompt):
device = self.text_encoder.device
batch = self.tokenizer(
prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt")
input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
encoder_hidden_states = self.text_encoder(
input_ids=input_ids, attention_mask=attention_mask)[0]
boolean_encoder_mask = (attention_mask == 1).to(device)
return encoder_hidden_states, boolean_encoder_mask
def encode_duration(self,duration):
return self.duration_emebdder(duration)
@torch.no_grad()
def inference_flow(self, prompt,
num_inference_steps=50,
timesteps=None,
guidance_scale=3,
duration=10,
disable_progress=False,
num_samples_per_prompt=1):
'''Only tested for single inference. Haven't test for batch inference'''
bsz = num_samples_per_prompt
device = self.transformer.device
scheduler = self.noise_scheduler
if not isinstance(prompt,list):
prompt = [prompt]
if not isinstance(duration,torch.Tensor):
duration = torch.tensor([duration],device=device)
classifier_free_guidance = guidance_scale > 1.0
duration_hidden_states = self.encode_duration(duration)
if classifier_free_guidance:
bsz = 2 * num_samples_per_prompt
encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt=num_samples_per_prompt)
duration_hidden_states = duration_hidden_states.repeat(bsz,1,1)
else:
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt,num_samples_per_prompt=num_samples_per_prompt)
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled)
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
device,
timesteps,
sigmas
)
latents = torch.randn(num_samples_per_prompt,self.audio_seq_len,64)
weight_dtype = latents.dtype
progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
audio_ids = torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
timesteps = timesteps.to(device)
latents = latents.to(device)
encoder_hidden_states = encoder_hidden_states.to(device)
for i, t in enumerate(timesteps):
latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
noise_pred = self.transformer(
hidden_states=latents_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=torch.tensor([t/1000],device=device),
guidance = None,
pooled_projections=pooled_projection,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=audio_ids,
return_dict=False,
)[0]
if classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents
def forward(self,
latents,
prompt,
duration=torch.tensor([10]),
sft=True
):
device = latents.device
audio_seq_length = self.audio_seq_len
bsz = latents.shape[0]
encoder_hidden_states, boolean_encoder_mask = self.encode_text(prompt)
duration_hidden_states = self.encode_duration(duration)
mask_expanded = boolean_encoder_mask.unsqueeze(-1).expand_as(encoder_hidden_states)
masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
pooled = torch.nanmean(masked_data, dim=1)
pooled_projection = self.fc(pooled)
## Add duration hidden states to encoder hidden states
encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1) ## (bs,seq_len,dim)
txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
audio_ids = torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)
if sft:
if self.uncondition:
mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
if len(mask_indices) > 0:
encoder_hidden_states[mask_indices] = 0
noise = torch.randn_like(latents)
u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal',
batch_size=bsz,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps/1000,
return_dict=False)[0]
target = noise - latents
loss = torch.mean(
( (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1,
)
loss = loss.mean()
raw_model_loss, raw_ref_loss,implicit_acc = 0,0,0 ## default this to 0 if doing sft
else:
encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
pooled_projection = pooled_projection.repeat(2,1)
noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1) ## Have to sample same noise for preferred and rejected
u = compute_density_for_timestep_sampling(
weighting_scheme='logit_normal',
batch_size=bsz//2,
logit_mean=0,
logit_std=1,
mode_scale=None,
)
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
timesteps = timesteps.repeat(2)
sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
model_pred = self.transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps/1000,
return_dict=False)[0]
target = noise - latents
model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
model_losses_w, model_losses_l = model_losses.chunk(2)
model_diff = model_losses_w - model_losses_l
raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())
with torch.no_grad():
ref_preds = self.ref_transformer(
hidden_states=noisy_model_input,
encoder_hidden_states=encoder_hidden_states,
pooled_projections=pooled_projection,
img_ids=audio_ids,
txt_ids=txt_ids,
guidance=None,
timestep=timesteps/1000,
return_dict=False)[0]
ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
scale_term = -0.5 * self.beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
implicit_acc = (scale_term * (model_diff - ref_diff) > 0).sum().float() / inside_term.size(0)
loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()
## raw_model_loss, raw_ref_loss, implicit_acc is used to help to analyze dpo behaviour.
return loss, raw_model_loss, raw_ref_loss, implicit_acc

11
requirements.txt Normal file
View File

@ -0,0 +1,11 @@
torch==2.4.0
torchaudio==2.4.0
torchlibrosa==0.1.0
torchvision==0.19.0
transformers==4.44.0
diffusers==0.30.0
accelerate==0.34.2
datasets==2.21.0
librosa
tqdm
wandb

61
tangoflux Normal file
View File

@ -0,0 +1,61 @@
from diffusers import AutoencoderOobleck
import torch
from transformers import T5EncoderModel,T5TokenizerFast
from diffusers import FluxTransformer2DModel
from torch import nn
from typing import List
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_density_for_timestep_sampling
import copy
import torch.nn.functional as F
import numpy as np
from src.model import TangoFlux
from huggingface_hub import snapshot_download
from tqdm import tqdm
from typing import Optional,Union,List
from datasets import load_dataset, Audio
from math import pi
import json
import inspect
import yaml
from safetensors.torch import load_file
class TangoFluxInference:
def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
self.vae = AutoencoderOobleck()
paths = snapshot_download(repo_id=name)
vae_weights = load_file("{}/vae.safetensors".format(paths))
self.vae.load_state_dict(vae_weights)
weights = load_file("{}/tangoflux.safetensors".format(paths))
with open('{}/config.json'.format(paths),'r') as f:
config = json.load(f)
self.model = TangoFlux(config)
self.model.load_state_dict(weights,strict=False)
# _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
self.vae.to(device)
self.model.to(device)
def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
with torch.no_grad():
latents = self.model.inference_flow(prompt,
duration=duration,
num_inference_steps=steps,
guidance_scale=guidance_scale)
wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
waveform_end = int(duration * self.vae.config.sampling_rate)
wave = wave[:, :waveform_end]
return wave

BIN
vae.safetensors (Stored with Git LFS) Normal file

Binary file not shown.