Module ml4opf.loss_functions.penalty

Penalize constraint violations

Classes

class PenaltyLoss (v: OPFViolation,
exclude_keys: str | list[str] | None = None,
multipliers: torch.Tensor | dict[str, torch.Tensor] | None = None)

PenaltyLoss penalizes constraint violations in the loss.

exclude_keys selects which violations are left out of the penalty: None uses all violations, "all" skips every violation, and a list of keys skips only the named violations.

Initialize the PenaltyLoss module.
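The exclude_keys rule described above can be sketched in plain Python. This is an illustration only, not the ML4OPF implementation: the helper name active_violation_keys and the key names in the usage note are hypothetical, and treating a single string other than "all" as one key to skip is an assumption based on the str type in the signature.

```python
from __future__ import annotations

from typing import Iterable


def active_violation_keys(
    available: Iterable[str],
    exclude_keys: str | list[str] | None = None,
) -> list[str]:
    """Return the violation keys that should still contribute to the penalty."""
    available = list(available)
    if exclude_keys is None:
        # None: use every violation.
        return available
    if exclude_keys == "all":
        # "all": skip every violation, so nothing is penalized.
        return []
    if isinstance(exclude_keys, str):
        # Assumption: any other single string names one key to skip.
        exclude_keys = [exclude_keys]
    skipped = set(exclude_keys)
    # Preserve the original ordering of the available keys.
    return [key for key in available if key not in skipped]
```

For example, with made-up keys, active_violation_keys(["pg_bound", "vm_bound", "thermal"], ["thermal"]) keeps only the first two keys, while passing "all" returns an empty list.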

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self,
base_loss: torch.Tensor,
exclude_keys: str | list[str] | None = None,
**calc_violation_inputs: torch.Tensor) ‑> torch.Tensor

Compute the PenaltyLoss for a batch of samples.
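As a rough picture of what a penalty loss over a batch can look like, the sketch below adds multiplier-weighted violations to a base loss. It is not the library's forward method: the function penalized_loss, the violation dictionary, the key names, and the choice to average violations over the constraint dimension and return one value per sample are all assumptions made for illustration.

```python
import torch


def penalized_loss(base_loss, violations, multipliers, exclude_keys=None):
    """Add multiplier-weighted mean violations to a per-sample base loss.

    base_loss:   tensor of shape (batch,)
    violations:  dict mapping a key to a nonnegative tensor (batch, n_constraints)
    multipliers: dict mapping the same keys to broadcastable tensors
    """
    if exclude_keys is None:
        skipped = set()
    elif exclude_keys == "all":
        skipped = set(violations)
    elif isinstance(exclude_keys, str):
        skipped = {exclude_keys}  # assumption: a single key given as a string
    else:
        skipped = set(exclude_keys)

    total = base_loss.clone()
    for key, violation in violations.items():
        if key in skipped:
            continue
        # Weight each violation, then average over the constraint dimension.
        total = total + (multipliers[key] * violation).mean(dim=-1)
    return total


# Hypothetical two-sample batch with two made-up constraint types.
base = torch.tensor([1.0, 2.0])
violations = {
    "bound": torch.tensor([[0.0, 2.0], [1.0, 1.0]]),
    "flow": torch.tensor([[4.0], [0.0]]),
}
multipliers = {"bound": torch.tensor(0.5), "flow": torch.tensor(2.0)}
loss = penalized_loss(base, violations, multipliers)
```

Passing exclude_keys="all" in this sketch returns the base loss unchanged, which mirrors the "skip all violations" behavior described for the class.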

def init_mults(self, multipliers: torch.Tensor | dict[str, torch.Tensor] | None = None)

Initialize multipliers for each constraint type.
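The multipliers argument accepts None, a single tensor, or a dictionary of tensors. One plausible way to normalize those three forms into a per-key dictionary is sketched below. The default value of 1.0, the sharing of a single tensor across all keys, and the helper name init_multipliers_sketch are assumptions; the source does not state how init_mults handles each case.

```python
import torch


def init_multipliers_sketch(keys, multipliers=None):
    """Build one multiplier tensor per constraint key (illustrative only)."""
    keys = list(keys)
    if multipliers is None:
        # Assumption: default to a unit weight for every constraint type.
        return {key: torch.tensor(1.0) for key in keys}
    if isinstance(multipliers, torch.Tensor):
        # Assumption: a single tensor is copied to every constraint type.
        return {key: multipliers.clone() for key in keys}
    # Otherwise expect a dict that covers every key.
    missing = [key for key in keys if key not in multipliers]
    if missing:
        raise KeyError(f"no multiplier given for: {missing}")
    return {key: multipliers[key] for key in keys}
```

A dictionary form lets each constraint type carry its own weight, for instance a larger multiplier on a constraint that the model violates more often.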