🧩 PLADIS: Pushing the Limits of Attention in Diffusion Models at Inference Time by Leveraging Sparsity (ICCV'25)

August 1, 2025


by Kwanyoung Kim, Byeongsu Sim

   


🚢 Overview

SDXL Guidance methods: SDXL-CFG | SDXL-PAG | SDXL-SEG | SDXL-FreeU
SDXL Guidance Distilled Models: SDXL-DMD2 | SDXL-Lightning | Hyper-SDXL
FLUX: FLUX-Schnell | FLUX-dev
ControlNet: SDXL-ControlNet

🔥🔥 News

  • (🔥 New) [2025/7/31] 🚀 PLADIS code is released.
  • (🔥 New) [2025/6/25] PLADIS was accepted to ICCV 2025. 🎉🎉🎉

💡 Introduction

Existing guidance methods require extra inference passes through undesired paths, such as a null condition (CFG) or self-attention perturbed with an identity matrix (PAG) or blurred attention weights (SEG). In contrast, PLADIS avoids additional inference paths by extrapolating between sparse and dense attention within all cross-attention modules. It can be integrated with existing guidance approaches (CFG, PAG, SEG) and even guidance-distilled models (DMD2, SDXL-Lightning, Hyper-SDXL) by simply replacing the cross-attention module. PLADIS significantly improves text alignment and sample quality without requiring additional training or extra inference steps. It is compatible with SDXL and advanced model backbones like FLUX, as well as downstream tasks such as ControlNet.

Please check our paper for details.
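As a rough illustration of the extrapolation described above, the sketch below mixes dense (softmax) and sparse attention weights for a single query in plain Python. It is not the repository's implementation: the combination rule `dense + pladis_scale * (sparse - dense)`, the use of sparsemax (the α = 2 case of α-entmax) as the sparse map, and the helper names `pladis_weights` and `attend` are assumptions made for illustration. See the paper for the exact formulation.

```python
import math


def softmax(scores):
    """Dense attention weights: every key receives non-zero mass."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sparsemax(scores):
    """Sparse attention weights (the alpha = 2 case of alpha-entmax)."""
    z = sorted(scores, reverse=True)
    cumsum, tau = 0.0, 0.0
    for k, z_k in enumerate(z, start=1):
        cumsum += z_k
        if 1.0 + k * z_k > cumsum:
            # Key k is still in the support; update the threshold.
            tau = (cumsum - 1.0) / k
    return [max(s - tau, 0.0) for s in scores]


def pladis_weights(scores, pladis_scale):
    """Assumed rule: move from dense toward (and past) sparse weights."""
    dense = softmax(scores)
    sparse = sparsemax(scores)
    return [d + pladis_scale * (s - d) for d, s in zip(dense, sparse)]


def attend(weights, values):
    """Weighted sum of value vectors for one query."""
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]


scores = [2.0, 1.0, 0.1]                        # query-key logits for three text tokens
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # toy value vectors
for scale in (0.0, 1.0, 2.0):
    w = pladis_weights(scores, scale)
    print(scale, [round(x, 3) for x in w], [round(x, 3) for x in attend(w, values)])
```

Under this assumed rule, a scale of 0 gives ordinary softmax attention, a scale of 1 gives purely sparse attention, and values above 1 extrapolate past the sparse weights, so individual weights may become negative while the weights still sum to 1.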


🔧 Dependencies and Installation

We use the following version of diffusers:

- diffusers 0.33.1

To compute sparse attention with alpha-entmax, please install the following package:

pip install entmax
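For intuition about what alpha-entmax computes, here is a minimal pure-Python sketch. It uses the standard definition p_i = [(α − 1) z_i − τ]_+^(1/(α − 1)), with the threshold τ found by bisection so that the weights sum to 1. The function name `alpha_entmax` is hypothetical and this is only an illustration for small lists; the `entmax` package installed above is what should be used on real tensors.

```python
import math


def softmax(scores):
    """Dense baseline for comparison."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def alpha_entmax(scores, alpha=1.5, n_iter=60):
    """Illustrative alpha-entmax for alpha > 1, solved by bisection on tau."""
    if alpha <= 1.0:
        raise ValueError("this sketch needs alpha > 1 (alpha -> 1 recovers softmax)")
    z = [(alpha - 1.0) * s for s in scores]
    exponent = 1.0 / (alpha - 1.0)

    def probs(tau):
        return [max(z_i - tau, 0.0) ** exponent for z_i in z]

    # At tau = max(z) - 1 the total mass is >= 1; at tau = max(z) it is 0.
    lo, hi = max(z) - 1.0, max(z)
    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        if sum(probs(mid)) >= 1.0:
            lo = mid
        else:
            hi = mid

    p = probs(lo)
    total = sum(p)  # remove the small residual bisection error
    return [p_i / total for p_i in p]


scores = [2.0, 1.0, 0.1]
print("softmax   :", [round(x, 4) for x in softmax(scores)])
print("entmax 1.5:", [round(x, 4) for x in alpha_entmax(scores, alpha=1.5)])
print("sparsemax :", [round(x, 4) for x in alpha_entmax(scores, alpha=2.0)])
```

Larger values of α push more low-scoring entries to exactly zero, which is the sparsity that PLADIS relies on in cross-attention.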

🚀 Quick Start

We provide example code with and without PLADIS using various backbone models and tasks.

Parameter:

  • pladis_scale: The scale of PLADIS. We generally set this parameter to 1.5 or 2.0. Please adjust it according to the model backbone and use case.

🧩 Using PLADIS with SDXL

1. How to use PLADIS with CFG

import torch
from pipeline.pipeline_sdxl_seg import PLADISStableDiffusionXLPipeline
from diffusers.utils.torch_utils import randn_tensor
from accelerate.utils import set_seed
from diffusers.utils import make_image_grid

seed = 7896
set_seed(seed)
device = "cuda"
latent_input = randn_tensor(shape=(1,4,128,128),device=device, dtype=torch.float16)

pipe = PLADISStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.to(device)

prompt = "a man and woman sit next to each other in front of some wine "
output = []

# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=4.0,
    num_inference_steps=25,
    latents = latent_input,
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=5.0,
    num_inference_steps=25,
    pladis_scale = 2.0,
    latents = latent_input,
).images[0]
output.append(image)

grid = make_image_grid(output, rows=1, cols=2)
filename = f"CFG_comparison.png"
grid.save(filename, format='PNG')

2. How to use PLADIS with CFG and PAG

import torch
from pipeline.pipeline_sdxl_seg import PLADISStableDiffusionXLPipeline
from diffusers.utils.torch_utils import randn_tensor
from accelerate.utils import set_seed
from diffusers.utils import make_image_grid

seed = 7896
set_seed(seed)
device = "cuda"
latent_input = randn_tensor(shape=(1,4,128,128),device=device, dtype=torch.float16)

pipe = PLADISStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.to(device)

prompt = "a man and woman sit next to each other in front of some wine "
output = []
# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=4.0,
    pag_scale=3.0,
    num_inference_steps=25,
    latents = latent_input,
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=4.0,
    pag_scale=3.0,
    num_inference_steps=25,
    pladis_scale = 2.0,
    latents =latent_input,
).images[0]
output.append(image)

grid = make_image_grid(output, rows=1, cols=2)
filename = f"PAG_comparison.png"
grid.save(filename, format='PNG')

3. How to use PLADIS with CFG and SEG

import torch
from pipeline.pipeline_sdxl_seg import PLADISStableDiffusionXLPipeline
from diffusers.utils.torch_utils import randn_tensor
from accelerate.utils import set_seed
from diffusers.utils import make_image_grid

seed = 7896
set_seed(seed)
device = "cuda"
latent_input = randn_tensor(shape=(1,4,128,128),device=device, dtype=torch.float16)

pipe = PLADISStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.to(device)

output = []
prompt = "a man and woman sit next to each other in front of some wine "
# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=4.0,
    num_inference_steps=25,
    seg_scale= 3.0,
    seg_blur_sigma=100.0,
    latents = latent_input,
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=4.0,
    num_inference_steps=25,
    seg_scale= 3.0,
    seg_blur_sigma=100.0,
    pladis_scale = 2.0,
    latents = latent_input,
).images[0]
output.append(image)

grid = make_image_grid(output, rows=1, cols=2)
grid.save("SEG_comparison.png", format='PNG')

4. How to use PLADIS with FreeU

import torch
from pipeline.pipeline_sdxl import PLADISStableDiffusionXLPipeline
from diffusers.utils.torch_utils import randn_tensor
from accelerate.utils import set_seed
from diffusers.utils import make_image_grid

seed = 7896
set_seed(seed)
device = "cuda"
latent_input = randn_tensor(shape=(1,4,128,128),device=device, dtype=torch.float16)

pipe = PLADISStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
pipe.to(device)

output = []
prompt = "a man and woman sit next to each other in front of some wine "
# without PLADIS
image = pipe(
    prompt=prompt,
    num_inference_steps=25,
    latents = latent_input,
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    num_inference_steps=25,
    pladis_scale = 1.5,
    latents = latent_input,
).images[0]
output.append(image)

grid = make_image_grid(output, rows=1, cols=2)
grid.save("FreeU_comparison.png", format='PNG')

🧩 Using PLADIS with Guidance Distilled Model

1. How to use PLADIS with DMD2

import torch
from diffusers import LCMScheduler
from diffusers.utils import make_image_grid
from diffusers.utils.torch_utils import randn_tensor
from accelerate.utils import set_seed
from pipeline.pipeline_sdxl import PLADISStableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

device="cuda"
seed = 7896
set_seed(seed)
latent_input = randn_tensor(shape=(1,4,128,128),generator=None, device=device, dtype=torch.float16)

repo_name = "tianweiy/DMD2"
ckpt_name = f"dmd2_sdxl_4step_lora_fp16.safetensors"
pipe = PLADISStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
pipe.fuse_lora(lora_scale=1.0) 
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to(device)

prompt = 'A photo of sad cat, blue color.'
output = []
# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    timesteps=[999, 749, 499, 249],
    latents = latent_input,
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    timesteps=[999, 749, 499, 249],
    pladis_scale = 1.5,
    latents = latent_input,
).images[0]
output.append(image)

grid = make_image_grid(output, rows=1, cols=2)
grid.save("dmd_comparison.png", format='PNG')

2. How to use PLADIS with SDXL-Lightning

import torch
from diffusers import EulerDiscreteScheduler
from diffusers.utils import make_image_grid
from diffusers.utils.torch_utils import randn_tensor
from accelerate.utils import set_seed
from pipeline.pipeline_sdxl import PLADISStableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

device="cuda"
seed = 7896
set_seed(seed)
latent_input = randn_tensor(shape=(1,4,128,128),generator=None, device=device, dtype=torch.float16)

repo = "ByteDance/SDXL-Lightning"
ckpt_name = f"sdxl_lightning_4step_lora.safetensors"
pipe = PLADISStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.load_lora_weights(hf_hub_download(repo, ckpt_name))
pipe.fuse_lora(lora_scale=1.0) 
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.to(device)

prompt = 'A photo of sad cat, blue color.'
output = []
# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    latents = latent_input,
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    pladis_scale = 1.5,
    latents = latent_input,
).images[0]
output.append(image)

grid = make_image_grid(output, rows=1, cols=2)
grid.save("light_comparison.png", format='PNG')

3. How to use PLADIS with Hyper-SDXL

import torch
from diffusers import DDIMScheduler
from diffusers.utils import make_image_grid
from diffusers.utils.torch_utils import randn_tensor
from accelerate.utils import set_seed
from pipeline.pipeline_sdxl import PLADISStableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

device="cuda"
seed = 7896
set_seed(seed)
latent_input = randn_tensor(shape=(1,4,128,128),generator=None, device=device, dtype=torch.float16)

repo_name = "ByteDance/Hyper-SD"
ckpt_name = f"Hyper-SDXL-4steps-lora.safetensors"
pipe = PLADISStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
pipe.fuse_lora(lora_scale=1.0) 
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.to(device)

prompt = 'A photo of sad cat, blue color.'
output = []
# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    latents = latent_input,
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    pladis_scale = 1.5,
    latents = latent_input,
).images[0]
output.append(image)

grid = make_image_grid(output, rows=1, cols=2)
grid.save("hyper_comparison.png", format='PNG')

🧩 Using PLADIS with Flux

1. How to use PLADIS with Schnell

import torch
from pipeline.pipeline_flux_pladis import PLADISFluxPipeline
from diffusers.utils import make_image_grid

pipe = PLADISFluxPipeline.from_pretrained(
  "black-forest-labs/FLUX.1-schnell",
  torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

prompt = 'a cyberfunk happy smiling bear with a neon sign that says "PLADIS"'

output = []
# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    pladis_scale = 1.5,
    max_sequence_length=256,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
output.append(image)
grid = make_image_grid(output, rows=1, cols=2)
grid.save("flux_schnell_comparison.png", format='PNG')

2. How to use PLADIS with Dev

import torch
from pipeline.pipeline_flux_pladis import PLADISFluxPipeline
from diffusers.utils import make_image_grid

pipe = PLADISFluxPipeline.from_pretrained(
  "black-forest-labs/FLUX.1-dev",
  torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

prompt = 'a cyberfunk happy smiling bear with a neon sign that says "PLADIS"'
output = []
# without PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=50,
    max_sequence_length=256,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
output.append(image)
# with PLADIS
image = pipe(
    prompt=prompt,
    guidance_scale=0.0,
    num_inference_steps=50,
    pladis_scale = 1.5,
    max_sequence_length=256,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
output.append(image)
grid = make_image_grid(output, rows=1, cols=2)
grid.save("flux_dev.png", format='PNG')

🧩 Using PLADIS with ControlNet

1. How to use PLADIS under Depth Condition

import torch
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from pipeline.pipeline_controlnet import PLADISStableDiffusionXLControlNetPipeline
from diffusers import ControlNetModel, AutoencoderKL
from diffusers.utils import load_image, make_image_grid
from PIL import Image
import numpy as np

device=f"cuda"
seed = 7896
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

image_name = f"asset/car.jpg"

def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast(device):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

image = load_image(image_name)
image = get_depth_map(image)  

pipe = PLADISStableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
)
pipe.to(device)
prompt = "red car heavliy broken in the forest"
negative_prompt = "low quality, bad quality, sketches"
controlnet_conditioning_scale = 0.5

output = []
# without PLADIS
generator = torch.Generator(device=device).manual_seed(seed)
control_image = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=25,
        guidance_scale=1.0,
        seg_scale=3.0,
        seg_blur_sigma=10000000000.0,
        seg_applied_layers=['mid'],
        generator=generator,
    ).images[0]
output.append(control_image)
# with PLADIS
generator = torch.Generator(device=device).manual_seed(seed)
control_image = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=25,
        guidance_scale=1.0,
        seg_scale=3.0,
        seg_blur_sigma=10000000000.0,
        seg_applied_layers=['mid'],
        pladis_scale = 3.0,
        generator=generator,
    ).images[0]
output.append(control_image)
grid = make_image_grid(output, rows=1, cols=2)  # two images: without and with PLADIS
grid.save("Controlnet_comparison.png", format='PNG')

💪 To-Do List

We will try our best to achieve

  • [✅] PLADIS + ControlNet
  • [✅] Flux Inference Code (including schnell and dev)
  • [✅] SDXL Inference Code (including CFG, PAG, SEG, FreeU)

🤗 Acknowledgements

Thanks to the following open-source codebases for their wonderful work!

⚠️ License Information

We use standard licenses from the community for the models that we used in this paper. For further information, please refer to the paper.
In particular, the FLUX.1-dev model is used strictly for non-commercial research purposes, in accordance with its license.
We did not modify or fine-tune the model, and used it only during inference to evaluate and enhance performance. The model itself is not redistributed.
Please refer to: https://huggingface.co/black-forest-labs/FLUX.1-dev
This code is intended solely for academic and non-commercial research.

📖 BibTeX

@article{kim2025pladis,
  title={PLADIS: Pushing the Limits of Attention in Diffusion Models at Inference Time by Leveraging Sparsity},
  author={Kim, Kwanyoung and Sim, Byeongsu},
  journal={arXiv preprint arXiv:2503.07677},
  year={2025}
}