Instructions to use BiliSakura/pMF-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/pMF-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/pMF-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- Draw Things
- DiffusionBee
Update pMF-B-32/scheduler/scheduling_pmf.py
Browse files
pMF-B-32/scheduler/scheduling_pmf.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
| 9 |
+
from diffusers.utils import BaseOutput
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class PMFSchedulerOutput(BaseOutput):
|
| 14 |
+
prev_sample: torch.Tensor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PMFScheduler(SchedulerMixin, ConfigMixin):
|
| 18 |
+
"""Scheduler for Pixel Mean Flow sampling (t: 1 -> 0)."""
|
| 19 |
+
|
| 20 |
+
order = 1
|
| 21 |
+
|
| 22 |
+
@register_to_config
|
| 23 |
+
def __init__(self, num_train_timesteps: int = 1000):
|
| 24 |
+
del num_train_timesteps
|
| 25 |
+
self.timesteps: Optional[torch.Tensor] = None
|
| 26 |
+
self.num_inference_steps: Optional[int] = None
|
| 27 |
+
self._step_index: Optional[int] = None
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def init_noise_sigma(self) -> float:
|
| 31 |
+
return 1.0
|
| 32 |
+
|
| 33 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device, None] = None) -> None:
|
| 34 |
+
if num_inference_steps < 1:
|
| 35 |
+
raise ValueError("num_inference_steps must be >= 1.")
|
| 36 |
+
self.num_inference_steps = num_inference_steps
|
| 37 |
+
self.timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=torch.float32)
|
| 38 |
+
self._step_index = 0
|
| 39 |
+
|
| 40 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
| 41 |
+
del timestep
|
| 42 |
+
return sample
|
| 43 |
+
|
| 44 |
+
def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
|
| 45 |
+
if self._step_index is not None:
|
| 46 |
+
return self._step_index
|
| 47 |
+
if self.timesteps is None:
|
| 48 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 49 |
+
if timestep is None:
|
| 50 |
+
return 0
|
| 51 |
+
|
| 52 |
+
t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
|
| 53 |
+
matches = (self.timesteps - t_value).abs() < 1e-6
|
| 54 |
+
if matches.any():
|
| 55 |
+
return int(matches.nonzero(as_tuple=False)[0].item())
|
| 56 |
+
return 0
|
| 57 |
+
|
| 58 |
+
def step(
|
| 59 |
+
self,
|
| 60 |
+
model_output: torch.Tensor,
|
| 61 |
+
timestep: Union[float, torch.Tensor, None],
|
| 62 |
+
sample: torch.Tensor,
|
| 63 |
+
return_dict: bool = True,
|
| 64 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 65 |
+
) -> Union[PMFSchedulerOutput, Tuple[torch.Tensor]]:
|
| 66 |
+
del generator
|
| 67 |
+
if self.timesteps is None:
|
| 68 |
+
raise ValueError("Call `set_timesteps` before `step`.")
|
| 69 |
+
|
| 70 |
+
step_index = self._resolve_step_index(timestep)
|
| 71 |
+
if step_index >= len(self.timesteps) - 1:
|
| 72 |
+
raise ValueError("Scheduler has already reached the final timestep.")
|
| 73 |
+
|
| 74 |
+
t = self.timesteps[step_index]
|
| 75 |
+
t_next = self.timesteps[step_index + 1]
|
| 76 |
+
dt = (t - t_next).to(dtype=sample.dtype, device=sample.device)
|
| 77 |
+
while dt.ndim < sample.ndim:
|
| 78 |
+
dt = dt.unsqueeze(-1)
|
| 79 |
+
|
| 80 |
+
prev_sample = sample - dt * model_output
|
| 81 |
+
self._step_index = step_index + 1
|
| 82 |
+
|
| 83 |
+
if not return_dict:
|
| 84 |
+
return (prev_sample,)
|
| 85 |
+
return PMFSchedulerOutput(prev_sample=prev_sample)
|