| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import html |
| | from typing import List, Optional, Union |
| |
|
| | import regex as re |
| | import torch |
| | from transformers import AutoTokenizer, UMT5EncoderModel |
| |
|
| | from diffusers.configuration_utils import FrozenDict |
| | from diffusers.guiders import ClassifierFreeGuidance |
| | from diffusers.utils import is_ftfy_available, logging |
| | from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState |
| | from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| | ComponentSpec, |
| | ConfigSpec, |
| | InputParam, |
| | OutputParam, |
| | ) |
| | from diffusers.modular_pipelines import WanModularPipeline |
| |
|
| |
|
# `ftfy` is an optional dependency; bind it at module level only when the
# environment reports it as installed.
if is_ftfy_available():
    import ftfy


# Logger for this module, obtained through diffusers' logging utilities.
logger = logging.get_logger(__name__)
| |
|
| |
|
def basic_clean(text):
    """Normalize raw prompt text before tokenization.

    Applies ``ftfy.fix_text`` when ``ftfy`` is installed, then decodes HTML
    entities twice (so double-escaped input such as ``&amp;lt;`` becomes ``<``)
    and strips leading/trailing whitespace.

    ``ftfy`` is treated as optional: when it cannot be imported, the ftfy repair
    step is skipped instead of failing with ``NameError``.

    Args:
        text: The prompt string to clean.

    Returns:
        The cleaned, stripped string.
    """
    # Import locally so this helper works whether or not ftfy is installed;
    # after the first successful import the module is served from sys.modules.
    try:
        import ftfy
    except ImportError:
        ftfy = None
    if ftfy is not None:
        text = ftfy.fix_text(text)
    # Unescape twice to also handle entities that were HTML-escaped two times.
    text = html.unescape(html.unescape(text))
    return text.strip()
| |
|
| |
|
def whitespace_clean(text):
    """Return *text* with every whitespace run collapsed to one space and both ends trimmed."""
    return re.sub(r"\s+", " ", text).strip()
| |
|
| |
|
def prompt_clean(text):
    """Return *text* after running ``basic_clean`` and then ``whitespace_clean`` on it."""
    cleaned = basic_clean(text)
    return whitespace_clean(cleaned)
| |
|
| |
|
class WanRTTextEncoderStep(ModularPipelineBlocks):
    """Modular pipeline block that encodes text prompts with a UMT5 text encoder.

    Reads ``prompt`` (or pre-computed ``prompt_embeds`` / ``negative_prompt_embeds``)
    from the pipeline state and writes ``prompt_embeds`` and
    ``negative_prompt_embeds`` back to it.

    NOTE(review): ``__call__`` hard-codes ``prepare_unconditional_embeds = False``,
    so this block never computes negative embeddings even though a
    ``ClassifierFreeGuidance`` guider is declared; any ``negative_prompt_embeds``
    already present in the state are passed through unchanged. The
    ``attention_kwargs`` input is also not read here. Presumably intentional for
    this variant -- confirm.
    """

    model_name = "wan"

    @property
    def description(self) -> str:
        return "Text Encoder step that generate text_embeddings to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                description="The prompt or prompts to guide the video generation",
            ),
            InputParam(
                "negative_prompt",
                description="The prompt or prompts not to guide the video generation",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the image generation",
            ),
            InputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                description="negative text embeddings used to guide the image generation",
            ),
            InputParam(
                "attention_kwargs",
                description="Additional keyword arguments to pass to the attention mechanism",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="negative text embeddings used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Validate the prompt-related fields of ``block_state``.

        Raises:
            ValueError: If ``prompt`` is neither a ``str`` nor a ``list``, or if
                neither ``prompt`` nor ``prompt_embeds`` is provided. Without the
                second check, encoding fails later with an opaque error on ``None``.
        """
        if block_state.prompt is not None and not isinstance(block_state.prompt, (str, list)):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
            )
        if block_state.prompt is None and block_state.prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and"
                " `prompt_embeds` undefined."
            )

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        max_sequence_length: int,
        device: torch.device,
    ):
        """Encode ``prompt`` with the UMT5 text encoder.

        Each prompt is cleaned with ``prompt_clean``, then tokenized and padded or
        truncated to ``max_sequence_length`` tokens. The tokens are encoded, and
        hidden states at padding positions are replaced with zeros.

        Args:
            components: Pipeline components providing ``tokenizer`` and ``text_encoder``.
            prompt: A single prompt or a list of prompts.
            max_sequence_length: Number of token positions in the output.
            device: Device to run the text encoder on and to place the result on.

        Returns:
            A tensor of shape ``(num_prompts, max_sequence_length, hidden_dim)`` in
            the text encoder's dtype, on ``device``.
        """
        dtype = components.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(u) for u in prompt]

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        # Number of real (non-padding) tokens in each prompt.
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_embeds = components.text_encoder(
            text_input_ids.to(device), mask.to(device)
        ).last_hidden_state
        # Also pin the device: the encoder may live on a different device than
        # the execution device (e.g. with offloading).
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        # Keep only the real-token states, then zero-pad back to max_sequence_length
        # so padding positions carry exactly zero instead of encoder output.
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [
                torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))])
                for u in prompt_embeds
            ],
            dim=0,
        )

        return prompt_embeds

    @staticmethod
    def _duplicate_embeds(embeds: torch.Tensor, num_videos_per_prompt: int) -> torch.Tensor:
        """Repeat each batch entry of a 3-D ``(batch, seq, dim)`` tensor ``num_videos_per_prompt`` times.

        Copies of the same entry stay adjacent. The tensor's own batch size and
        sequence length are used, so positive and negative embeddings of
        different shapes are each handled correctly.
        """
        batch, seq_len, _ = embeds.shape
        # repeat + view (rather than repeat_interleave) keeps the original
        # "mps friendly" duplication scheme.
        embeds = embeds.repeat(1, num_videos_per_prompt, 1)
        return embeds.view(batch * num_videos_per_prompt, seq_len, -1)

    @staticmethod
    def encode_prompt(
        components,
        prompt: Optional[Union[str, List[str]]],
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        prepare_unconditional_embeds: bool = True,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device; defaults to `components._execution_device`
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            prepare_unconditional_embeds (`bool`):
                whether to prepare unconditional (negative) embeddings or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. Defaults to `""` when negative
                embeddings must be computed. Only used when `prepare_unconditional_embeds` is `True` and
                `negative_prompt_embeds` is not given.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`, defaults to `512`):
                The maximum number of text tokens to be used for the generation process.

        Returns:
            `(prompt_embeds, negative_prompt_embeds)`. Both are duplicated `num_videos_per_prompt` times along the
            batch axis. `negative_prompt_embeds` is returned unchanged (possibly `None`) when
            `prepare_unconditional_embeds` is `False`.

        Raises:
            ValueError: If neither `prompt` nor `prompt_embeds` is given, or if `negative_prompt` has a different
                batch size than `prompt`.
            TypeError: If `negative_prompt` and `prompt` have different types.
        """
        device = device or components._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and"
                " `prompt_embeds` undefined."
            )
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = WanRTTextEncoderStep._get_t5_prompt_embeds(
                components, prompt, max_sequence_length, device
            )

        if prepare_unconditional_embeds and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            # Broadcast a single negative prompt across the whole batch.
            negative_prompt = (
                batch_size * [negative_prompt]
                if isinstance(negative_prompt, str)
                else negative_prompt
            )

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = WanRTTextEncoderStep._get_t5_prompt_embeds(
                components, negative_prompt, max_sequence_length, device
            )

        prompt_embeds = WanRTTextEncoderStep._duplicate_embeds(prompt_embeds, num_videos_per_prompt)

        if prepare_unconditional_embeds:
            # Use the negative tensor's own shape: it may differ in sequence
            # length from prompt_embeds (e.g. user-supplied prompt_embeds).
            negative_prompt_embeds = WanRTTextEncoderStep._duplicate_embeds(
                negative_prompt_embeds, num_videos_per_prompt
            )

        return prompt_embeds, negative_prompt_embeds

    @torch.no_grad()
    def __call__(
        self, components: WanModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Encode the prompt held in ``state`` and write the embeddings back to it.

        Returns the ``(components, state)`` pair. NOTE(review): the return
        annotation says ``PipelineState`` but a tuple is returned; the annotation
        is left as-is to keep the signature unchanged.
        """
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        # Negative embeddings are never computed by this block (see class docstring).
        block_state.prepare_unconditional_embeds = False
        block_state.device = components._execution_device

        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = WanRTTextEncoderStep.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            prompt_embeds=block_state.prompt_embeds,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
        )
        # Hand a contiguous tensor to downstream blocks.
        block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

        self.set_block_state(state, block_state)
        return components, state
| |
|