BiliSakura commited on
Commit
b0d5d35
·
verified ·
1 Parent(s): bbb2e16

Update pMF-B-32/scheduler/scheduling_pmf.py

Browse files
pMF-B-32/scheduler/scheduling_pmf.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
+ from diffusers.utils import BaseOutput
10
+
11
+
12
+ @dataclass
13
+ class PMFSchedulerOutput(BaseOutput):
14
+ prev_sample: torch.Tensor
15
+
16
+
17
+ class PMFScheduler(SchedulerMixin, ConfigMixin):
18
+ """Scheduler for Pixel Mean Flow sampling (t: 1 -> 0)."""
19
+
20
+ order = 1
21
+
22
+ @register_to_config
23
+ def __init__(self, num_train_timesteps: int = 1000):
24
+ del num_train_timesteps
25
+ self.timesteps: Optional[torch.Tensor] = None
26
+ self.num_inference_steps: Optional[int] = None
27
+ self._step_index: Optional[int] = None
28
+
29
+ @property
30
+ def init_noise_sigma(self) -> float:
31
+ return 1.0
32
+
33
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device, None] = None) -> None:
34
+ if num_inference_steps < 1:
35
+ raise ValueError("num_inference_steps must be >= 1.")
36
+ self.num_inference_steps = num_inference_steps
37
+ self.timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=torch.float32)
38
+ self._step_index = 0
39
+
40
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
41
+ del timestep
42
+ return sample
43
+
44
+ def _resolve_step_index(self, timestep: Union[float, torch.Tensor, None]) -> int:
45
+ if self._step_index is not None:
46
+ return self._step_index
47
+ if self.timesteps is None:
48
+ raise ValueError("Call `set_timesteps` before `step`.")
49
+ if timestep is None:
50
+ return 0
51
+
52
+ t_value = float(timestep) if not isinstance(timestep, torch.Tensor) else float(timestep.flatten()[0])
53
+ matches = (self.timesteps - t_value).abs() < 1e-6
54
+ if matches.any():
55
+ return int(matches.nonzero(as_tuple=False)[0].item())
56
+ return 0
57
+
58
+ def step(
59
+ self,
60
+ model_output: torch.Tensor,
61
+ timestep: Union[float, torch.Tensor, None],
62
+ sample: torch.Tensor,
63
+ return_dict: bool = True,
64
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
65
+ ) -> Union[PMFSchedulerOutput, Tuple[torch.Tensor]]:
66
+ del generator
67
+ if self.timesteps is None:
68
+ raise ValueError("Call `set_timesteps` before `step`.")
69
+
70
+ step_index = self._resolve_step_index(timestep)
71
+ if step_index >= len(self.timesteps) - 1:
72
+ raise ValueError("Scheduler has already reached the final timestep.")
73
+
74
+ t = self.timesteps[step_index]
75
+ t_next = self.timesteps[step_index + 1]
76
+ dt = (t - t_next).to(dtype=sample.dtype, device=sample.device)
77
+ while dt.ndim < sample.ndim:
78
+ dt = dt.unsqueeze(-1)
79
+
80
+ prev_sample = sample - dt * model_output
81
+ self._step_index = step_index + 1
82
+
83
+ if not return_dict:
84
+ return (prev_sample,)
85
+ return PMFSchedulerOutput(prev_sample=prev_sample)