Module ml4opf.loss_functions.penalty
Penalize constraint violations
Classes
class PenaltyLoss(v: OPFViolation,
exclude_keys: str | list[str] | None = None,
multipliers: torch.Tensor | dict[str, torch.Tensor] | None = None)
PenaltyLoss penalizes constraint violations in the loss.
exclude_keys is either None (use all violations), "all" (skip all violations), or a list of keys (skip only those violations).
Initialize the PenaltyLoss module.
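The three accepted forms of exclude_keys can be illustrated with a small, framework-free sketch. The helper select_keys and the example key names ("pg_lower", "pg_upper", "thermal") are hypothetical and not part of ml4opf. The sketch only mirrors the documented semantics: None keeps every violation, "all" drops every violation, and a list drops the named ones. How a single string other than "all" is handled is an assumption.

```python
from __future__ import annotations


def select_keys(
    all_keys: list[str],
    exclude_keys: str | list[str] | None = None,
) -> list[str]:
    """Return the violation keys to penalize (hypothetical helper, not ml4opf API)."""
    if exclude_keys is None:
        # None: use every violation.
        return list(all_keys)
    if exclude_keys == "all":
        # "all": skip every violation, so only the base loss would remain.
        return []
    if isinstance(exclude_keys, str):
        # Assumption: a single key name other than "all" acts like a one-item list.
        exclude_keys = [exclude_keys]
    excluded = set(exclude_keys)
    # Keep the original key order, dropping only the excluded ones.
    return [k for k in all_keys if k not in excluded]


keys = ["pg_lower", "pg_upper", "thermal"]
print(select_keys(keys))               # ['pg_lower', 'pg_upper', 'thermal']
print(select_keys(keys, "all"))        # []
print(select_keys(keys, ["thermal"]))  # ['pg_lower', 'pg_upper']
```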
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self,
base_loss: torch.Tensor,
exclude_keys: str | list[str] | None = None,
**calc_violation_inputs: torch.Tensor) -> torch.Tensor
Compute the PenaltyLoss for a batch of samples.
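As a rough illustration of what a penalty loss over a batch can look like, the sketch below adds one weighted term per constraint type to the base loss. It uses plain Python floats and lists in place of torch.Tensor so it runs without PyTorch. The function name penalty_loss, the dict layout of violations and multipliers, and the batch-mean reduction are all assumptions; the actual module obtains its violations through OPFViolation from the keyword inputs and may reduce them differently.

```python
from __future__ import annotations


def penalty_loss(
    base_loss: float,
    violations: dict[str, list[float]],
    multipliers: dict[str, float],
    exclude_keys: str | list[str] | None = None,
) -> float:
    """Hypothetical penalty loss: base_loss + sum_k mult_k * mean(violation_k)."""
    if exclude_keys == "all":
        # No penalty terms at all, only the base loss.
        return base_loss
    if isinstance(exclude_keys, str):
        # Assumption: a single key name acts like a one-item list.
        exclude_keys = [exclude_keys]
    excluded = set(exclude_keys or [])
    total = base_loss
    for key, viol in violations.items():
        if key in excluded:
            continue
        # Assumption: violations are non-negative magnitudes (0 when satisfied),
        # averaged over the batch before being weighted by the multiplier.
        total += multipliers[key] * sum(viol) / len(viol)
    return total


viol = {"pg_upper": [0.0, 0.25, 0.5], "thermal": [1.0, 0.0, 0.5]}
mults = {"pg_upper": 4.0, "thermal": 2.0}
print(penalty_loss(0.5, viol, mults))               # 2.5
print(penalty_loss(0.5, viol, mults, ["thermal"]))  # 1.5
print(penalty_loss(0.5, viol, mults, "all"))        # 0.5
```

Weighting each constraint type separately lets one tighten the penalty on, for example, thermal limits without changing the weight on generator bounds.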
def init_mults(self, multipliers: torch.Tensor | dict[str, torch.Tensor] | None = None)
Initialize multipliers for each constraint type.
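The accepted multipliers types suggest three cases: no multipliers, one value shared by all constraint types, or one value per constraint type. The sketch below shows one way to normalize these into a per-key dict. The function name init_mults_sketch, the default weight of 1.0 when multipliers is None, raising on missing keys, and the use of plain floats instead of torch.Tensor are assumptions for illustration only, not the ml4opf behavior.

```python
from __future__ import annotations


def init_mults_sketch(
    keys: list[str],
    multipliers: float | dict[str, float] | None = None,
    default: float = 1.0,
) -> dict[str, float]:
    """Hypothetical normalization of multipliers into one weight per constraint type."""
    if multipliers is None:
        # Assumption: unweighted penalty, every constraint type gets `default`.
        return {k: default for k in keys}
    if isinstance(multipliers, dict):
        # Assumption: every constraint type must have its own entry.
        missing = set(keys) - set(multipliers)
        if missing:
            raise KeyError(f"no multiplier for: {sorted(missing)}")
        # Extra entries not in `keys` are ignored.
        return {k: float(multipliers[k]) for k in keys}
    # A single value is shared across all constraint types.
    return {k: float(multipliers) for k in keys}


keys = ["pg_lower", "thermal"]
print(init_mults_sketch(keys))        # {'pg_lower': 1.0, 'thermal': 1.0}
print(init_mults_sketch(keys, 10.0))  # {'pg_lower': 10.0, 'thermal': 10.0}
print(init_mults_sketch(keys, {"pg_lower": 5.0, "thermal": 0.5}))
```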