Package ml4opf
ML4OPF: Machine Learning for OPF
This repository contains a collection of tools for applying machine learning to the optimal power flow (OPF) problem. Below are some common usage patterns:
Loading data
This is probably the most common usage, especially for those who already have their own models and wish to evaluate on the PGLearn datasets. ML4OPF makes loading data and splitting training/testing sets easy and reproducible.
from ml4opf import DCProblem
data_dir = ... # path to folder containing the data
problem = DCProblem(data_dir)  # optional **parse_kwargs are forwarded to the data parser
# extract tensors
train_pd = problem.train_data["input/pd"]
train_pg = problem.train_data["primal/pg"]
train_va = problem.train_data["primal/va"]
test_pd = problem.test_data["input/pd"]
test_pg = problem.test_data["primal/pg"]
test_va = problem.test_data["primal/va"]
# create a PyTorch dataset
torch_dataset, slices = problem.make_dataset()
Computing residuals
The ML4OPF OPFViolation modules provide a fast (using torch.jit), standard, and convenient way to calculate the residuals/violations of the OPF constraints, compute the objective function, and access other useful problem data such as incidence matrices.
v = problem.violation
pg_lower, pg_upper = v.pg_bound_residual(train_pg) # supply clamp=True to report violations only
obj = v.objective(train_pg)
gen_incidence = v.generator_incidence
Note that you can use the underlying functions directly, without instantiating the OPFViolation class, by accessing ml4opf.functional. This lets you perform the calculations without the data parsing or caching logic, but requires adopting the functional interface (ml4opf.functional) instead of the object-oriented interface (ml4opf.formulations).
import ml4opf.functional as MOF
gen_incidence = MOF.generator_incidence(v.gen_bus, v.n_bus, v.n_gen)
obj = MOF.DCP.objective(train_pg, v.c0, v.c1)
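To make the incidence-matrix idea concrete, here is a plain-Python sketch of a bus-by-generator incidence matrix built from a gen_bus index list, plus the matrix-vector product that aggregates per-generator values to buses. It assumes 0-based bus indices and the convention that entry [b][g] is 1 when generator g sits at bus b; the orientation, indexing, and dense/sparse representation returned by MOF.generator_incidence may differ. Both helper names are hypothetical.

```python
def generator_incidence_sketch(gen_bus, n_bus, n_gen):
    """Hypothetical helper: dense (n_bus x n_gen) 0/1 incidence matrix.

    Assumes gen_bus[g] is the 0-based index of the bus hosting generator g.
    """
    if len(gen_bus) != n_gen:
        raise ValueError("gen_bus must have one entry per generator")
    incidence = [[0] * n_gen for _ in range(n_bus)]
    for g, b in enumerate(gen_bus):
        incidence[b][g] = 1
    return incidence


def gen_to_bus(incidence, pg):
    """Aggregate per-generator values to per-bus values (matrix-vector product)."""
    return [sum(row[g] * pg[g] for g in range(len(pg))) for row in incidence]


# Example: 3 buses, 3 generators; generators 0 and 2 share bus 1.
A = generator_incidence_sketch([1, 0, 1], n_bus=3, n_gen=3)
print(A)                               # [[0, 1, 0], [1, 0, 1], [0, 0, 0]]
print(gen_to_bus(A, [2.0, 1.5, 0.5]))  # [1.5, 2.5, 0.0]
```

Multiplying by the incidence matrix is one common way to embed component-wise quantities at the bus level; it is shown here only for intuition.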
Implementing an OPFModel
In order to use the ML4OPF evaluation tools, you need to subclass the OPFModel class and implement a few methods. The typical pattern is to first write your model in the usual PyTorch fashion, by subclassing torch.nn.Module. Then, subclass OPFModel and implement the required methods. Below is an example where the original model is MyPyTorchModel and the wrapper is MyDCPModel.
import torch
from ml4opf import DCModel

N_LOADS = problem.violation.n_load
N_GEN = problem.violation.n_gen
N_BUS = problem.violation.n_bus

class MyPyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(N_LOADS, 16)
        self.fc2 = torch.nn.Linear(16, N_GEN)
        self.fc3 = torch.nn.Linear(16, N_BUS)

    def forward(self, pd):
        x = torch.relu(self.fc1(pd))
        pg_pred = self.fc2(x)  # active power generation per generator
        va_pred = self.fc3(x)  # voltage angle per bus
        return pg_pred, va_pred

class MyDCPModel(DCModel):
    def __init__(self, pytorch_model, problem):
        super().__init__()
        self.model = pytorch_model
        self.problem = problem

    def save_checkpoint(self, path_to_folder):
        torch.save(self.model.state_dict(), f"{path_to_folder}/model.pth")

    @classmethod
    def load_from_checkpoint(cls, path_to_folder, problem):
        pytorch_model = MyPyTorchModel()
        pytorch_model.load_state_dict(torch.load(f"{path_to_folder}/model.pth"))
        return cls(pytorch_model, problem)

    def predict(self, pd):
        pg, va = self.model(pd)
        return {"pg": pg, "va": va}
Using repair layers
A common issue with learning OPF is that the model may predict infeasible solutions. The ml4opf.layers module provides a collection of differentiable layers that can be used to repair infeasible solutions. For example, the BoundRepair layer can be used to repair solutions that violate bound constraints. The output of BoundRepair is guaranteed to be within the specified bounds.
from ml4opf.layers import BoundRepair

class BoundRepairPyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(N_LOADS, 16)
        self.fc2 = torch.nn.Linear(16, N_GEN)
        self.fc3 = torch.nn.Linear(16, N_BUS)
        # keep the predicted pg within the generator limits [pmin, pmax]
        self.bound_repair = BoundRepair(xmin=v.pmin, xmax=v.pmax, method="softplus")

    def forward(self, pd):
        x = torch.relu(self.fc1(pd))
        pg_pred = self.bound_repair(self.fc2(x))
        va_pred = self.fc3(x)
        return pg_pred, va_pred
The source code is organized into several submodules:
Sub-modules
ml4opf.formulations - ML4OPF Formulations
ml4opf.functional - Functional interface
ml4opf.layers - ML4OPF Layers
ml4opf.loss_functions - ML4OPF Loss Functions
ml4opf.models - ML4OPF Models …
ml4opf.parsers - ML4OPF Parsers
ml4opf.viz - Visualization utilities (plots & tables)
Classes
class ACModel
-
OPFModel for ACOPF.
Ancestors
- OPFModel
- abc.ABC
Subclasses
Class variables
var problem : ACProblem
var violation : ACViolation
Methods
def evaluate_model(self, reduction: str | None = None, inner_reduction: str | None = None) ‑> dict[str, torch.Tensor]
-
Evaluate the model on the test data.
Args
reduction (str, optional): Reduction method for the metrics. Defaults to None. Must be one of "mean", "sum", "max", "none". If specified, each value in the returned dictionary will be a scalar. Otherwise, they are arrays of shape (n_test_samples,).
inner_reduction (str, optional): Reduction method for turning metrics calculated per component into per-sample metrics. Defaults to None. Must be one of "mean", "sum", "max", "none".
Returns
dict[str, Tensor]: Dictionary containing Tensor metrics of the model's performance:
- vm_lower: Lower bound violation of the voltage magnitude.
- vm_upper: Upper bound violation of the voltage magnitude.
- pg_lower: Lower bound violation of the real power generation.
- pg_upper: Upper bound violation of the real power generation.
- qg_lower: Lower bound violation of the reactive power generation.
- qg_upper: Upper bound violation of the reactive power generation.
- thrm_1: Thermal limit violation at the from end of each branch.
- thrm_2: Thermal limit violation at the to end of each branch.
- p_balance: Active power balance violation.
- q_balance: Reactive power balance violation.
- pg_mae: Mean absolute error of the real power generation.
- qg_mae: Mean absolute error of the reactive power generation.
- vm_mae: Mean absolute error of the voltage magnitude.
- va_mae: Mean absolute error of the voltage angle (skipped if not bus-wise and va is not in the predictions).
- dva_mae: Mean absolute error of the angle difference (only if not bus-wise).
- obj_mape: Mean absolute percent error of the objective value.
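The interplay between per-sample error metrics and the outer reduction argument can be illustrated with a short plain-Python sketch. It assumes that pg_mae averages |pred - true| over generators for each sample and that obj_mape is |obj_pred - obj_true| / |obj_true| per sample (expressed as a fraction, not multiplied by 100); both definitions are assumptions, and the helper names are hypothetical.

```python
def per_sample_mae(pred, true):
    """Mean absolute error over components, one value per sample."""
    return [sum(abs(p - t) for p, t in zip(ps, ts)) / len(ps) for ps, ts in zip(pred, true)]


def per_sample_ape(obj_pred, obj_true):
    """Absolute percent error of the objective (as a fraction), one value per sample."""
    return [abs(p - t) / abs(t) for p, t in zip(obj_pred, obj_true)]


def reduce_metric(values, reduction=None):
    """Mimics the outer `reduction` choices; None or "none" keeps per-sample values."""
    if reduction in (None, "none"):
        return values
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "sum":
        return sum(values)
    if reduction == "max":
        return max(values)
    raise ValueError(f"unknown reduction: {reduction}")


pg_pred = [[1.0, 2.0], [0.5, 0.5]]
pg_true = [[1.5, 2.0], [0.5, 1.5]]
print(reduce_metric(per_sample_mae(pg_pred, pg_true)))         # [0.25, 0.5]
print(reduce_metric(per_sample_mae(pg_pred, pg_true), "max"))  # 0.5
print(per_sample_ape([110.0, 90.0], [100.0, 100.0]))           # [0.1, 0.1]
```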
def predict(self, pd: torch.Tensor, qd: torch.Tensor) ‑> dict[str, torch.Tensor]
-
Predict the ACOPF primal solution for a given set of loads.
Args
pd (Tensor): Active power demand per load.
qd (Tensor): Reactive power demand per load.
Returns
dict[str, Tensor]: Dictionary containing the predicted primal solution:
- pg: Active power generation per generator or per bus.
- qg: Reactive power generation per generator or per bus.
- vm: Voltage magnitude per bus.
- va: Voltage angle per bus.
Inherited members
class ACProblem (data_directory: str, dataset_name: str = 'ACOPF', **parse_kwargs)
-
OPFProblem for ACOPF.
Ancestors
- OPFProblem
- abc.ABC
Instance variables
prop default_combos : dict[str, list[str]]
-
Default combos for ACOPF:
- input: pd, qd
- target: pg, qg, vm, va
prop default_order : list[str]
-
Default order for ACOPF: input, target
prop feasibility_check : dict[str, str]
-
Default feasibility check for ACOPF:
- termination_status: "LOCALLY_SOLVED"
- primal_status: "FEASIBLE_POINT"
- dual_status: "FEASIBLE_POINT"
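One plausible way to apply a feasibility check like this is to keep only the samples whose solver statuses match every entry. The sketch below assumes that statuses are available as plain per-sample dictionaries and uses a hypothetical helper name; it does not reflect how OPFProblem actually stores or filters PGLearn data.

```python
FEASIBILITY_CHECK = {
    "termination_status": "LOCALLY_SOLVED",
    "primal_status": "FEASIBLE_POINT",
    "dual_status": "FEASIBLE_POINT",
}


def feasible_indices(samples, check=FEASIBILITY_CHECK):
    """Return the indices of samples whose status fields all match `check`."""
    return [
        i
        for i, sample in enumerate(samples)
        if all(sample.get(key) == value for key, value in check.items())
    ]


samples = [
    {"termination_status": "LOCALLY_SOLVED", "primal_status": "FEASIBLE_POINT", "dual_status": "FEASIBLE_POINT"},
    {"termination_status": "TIME_LIMIT", "primal_status": "FEASIBLE_POINT", "dual_status": "FEASIBLE_POINT"},
    {"termination_status": "LOCALLY_SOLVED", "primal_status": "INFEASIBLE_POINT", "dual_status": "FEASIBLE_POINT"},
]
print(feasible_indices(samples))  # [0]
```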
prop violation : ACViolation
-
ACViolation object, created upon first access.
Inherited members
class ACViolation (data: dict[str, torch.Tensor])
-
OPFViolation for ACOPF.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- OPFViolation
- torch.nn.modules.module.Module
- abc.ABC
Methods
def angle_difference(self, va: torch.Tensor) ‑> torch.Tensor
-
Compute the angle differences per branch given the voltage angles per bus.
\text{dva} = \boldsymbol{\theta}_{f} - \boldsymbol{\theta}_{t}
Args
va (Tensor): Voltage angles per bus ( \boldsymbol{\theta} ). (batch_size, nbus)
Returns
Tensor: Angle differences per branch. (batch_size, nbranch)
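The angle-difference formula indexes the bus angles by each branch's from-bus f and to-bus t. The plain-Python sketch below assumes 0-based branch endpoint lists named bus_fr and bus_to (hypothetical names) and a batch given as a list of per-sample lists, so the output has shape (batch_size, nbranch).

```python
def angle_difference_sketch(va, bus_fr, bus_to):
    """dva[s][l] = va[s][bus_fr[l]] - va[s][bus_to[l]] for every sample s and branch l."""
    return [[sample[f] - sample[t] for f, t in zip(bus_fr, bus_to)] for sample in va]


# Two samples, three buses, two branches: 0 -> 1 and 1 -> 2.
va = [[0.0, -0.5, -0.75], [0.0, 0.25, 0.0]]
print(angle_difference_sketch(va, bus_fr=[0, 1], bus_to=[1, 2]))  # [[0.5, 0.25], [-0.25, 0.25]]
```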
def balance_residual(self,
pd: torch.Tensor,
qd: torch.Tensor,
pg: torch.Tensor,
qg: torch.Tensor,
vm: torch.Tensor,
pf: torch.Tensor,
pt: torch.Tensor,
qf: torch.Tensor,
qt: torch.Tensor,
clamp: bool = False,
embed_method: str = 'pad') ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the power balance residual.
Component-wise tensors are first embedded to the bus level using embed_method. The shunt parameters g_s, b_s are assumed to be constant, matching the reference case.
\text{p_viol} = \text{pg_bus} - \text{pd_bus} - \text{pt_bus} - \text{pf_bus} - \text{gs_bus} \times \text{vm}^2
\text{q_viol} = \text{qg_bus} - \text{qd_bus} - \text{qt_bus} - \text{qf_bus} + \text{bs_bus} \times \text{vm}^2
Args
pd (Tensor): Active power demand per bus. (batch_size, nbus)
qd (Tensor): Reactive power demand per bus. (batch_size, nbus)
pg (Tensor): Active power generation per generator. (batch_size, ngen)
qg (Tensor): Reactive power generation per generator. (batch_size, ngen)
vm (Tensor): Voltage magnitude per bus. (batch_size, nbus)
pf (Tensor): Active power flow from bus per branch. (batch_size, nbranch)
pt (Tensor): Active power flow to bus per branch. (batch_size, nbranch)
qf (Tensor): Reactive power flow from bus per branch. (batch_size, nbranch)
qt (Tensor): Reactive power flow to bus per branch. (batch_size, nbranch)
clamp (bool, optional): Apply an absolute value to the residual. Defaults to False.
embed_method (str, optional): Embedding method for bus-level components. Defaults to 'pad'. Must be one of 'pad', 'dense_matrix', or 'matrix'. See IncidenceMixin.*_to_bus.
Returns
Tensor: Power balance residual for active power. (batch_size, nbus)
Tensor: Power balance residual for reactive power. (batch_size, nbus)
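To show how the active-power balance formula combines component-wise quantities, the sketch below embeds generator, load, and branch-flow values into bus slots by summation (a scatter-add), then evaluates p_viol for a single sample. The summation-based embedding, the 0-based index lists, and all helper names are assumptions made for illustration; this is not the library's 'pad' or matrix implementation.

```python
def to_bus(values, bus_index, n_bus):
    """Sum component values into their buses (scatter-add)."""
    out = [0.0] * n_bus
    for value, bus in zip(values, bus_index):
        out[bus] += value
    return out


def p_balance_residual(pd, pg, pf, pt, vm, gs, load_bus, gen_bus, bus_fr, bus_to):
    """p_viol = pg_bus - pd_bus - pt_bus - pf_bus - gs_bus * vm^2 for one sample."""
    n_bus = len(vm)
    pg_bus = to_bus(pg, gen_bus, n_bus)
    pd_bus = to_bus(pd, load_bus, n_bus)
    pf_bus = to_bus(pf, bus_fr, n_bus)  # flow injected into each branch at its from-bus
    pt_bus = to_bus(pt, bus_to, n_bus)  # flow injected into each branch at its to-bus
    return [
        pg_bus[i] - pd_bus[i] - pt_bus[i] - pf_bus[i] - gs[i] * vm[i] ** 2
        for i in range(n_bus)
    ]


# Two buses, one lossless branch 0 -> 1 carrying 1.0 p.u.; generator at bus 0, load at bus 1.
res = p_balance_residual(
    pd=[1.0], pg=[1.0], pf=[1.0], pt=[-1.0], vm=[1.0, 1.0], gs=[0.0, 0.0],
    load_bus=[1], gen_bus=[0], bus_fr=[0], bus_to=[1],
)
print(res)  # [0.0, 0.0]
```

A balanced operating point gives a zero residual at every bus; under-generation shows up as a negative entry at the generator's bus.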
def calc_violations(self,
pd: torch.Tensor,
qd: torch.Tensor,
pg: torch.Tensor,
qg: torch.Tensor,
vm: torch.Tensor,
va: torch.Tensor | None = None,
dva: torch.Tensor | None = None,
flows: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
reduction: str | None = 'mean',
clamp: bool = True) ‑> dict[str, torch.Tensor]
-
Calculate the violations of all the constraints.
The reduction is applied across the component dimension: e.g., 'mean' computes violation.mean(dim=1), where each violation has shape (batch, components).
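As a minimal sketch of this component-wise reduction, assume each residual is a batch of per-component lists; clamp=True keeps only the positive part (the violations), and the reduction then collapses the component axis to one value per sample. The helper name is hypothetical.

```python
def reduce_components(residual, reduction="mean", clamp=True):
    """Reduce a (batch, components) residual to one value per sample."""
    rows = [[max(r, 0.0) for r in row] for row in residual] if clamp else residual
    if reduction == "mean":
        return [sum(row) / len(row) for row in rows]
    if reduction == "sum":
        return [sum(row) for row in rows]
    if reduction in (None, "none"):
        return rows
    raise ValueError(f"unknown reduction: {reduction}")


residual = [[-0.5, 0.25, 0.75, 0.0], [0.0, -1.0, 0.5, 1.5]]
print(reduce_components(residual))                      # [0.25, 0.5]
print(reduce_components(residual, "sum"))               # [1.0, 2.0]
print(reduce_components(residual, "mean", clamp=False)) # [0.125, 0.25]
```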
Args
pd (Tensor): Real power demand. (batch, loads)
qd (Tensor): Reactive power demand. (batch, loads)
pg (Tensor): Real power generation. (batch, gens)
qg (Tensor): Reactive power generation. (batch, gens)
vm (Tensor): Voltage magnitude. (batch, buses)
va (Tensor, optional): Voltage angle. (batch, buses)
dva (Tensor, optional): Voltage angle difference. (batch, branches)
flows (tuple[Tensor, Tensor, Tensor, Tensor], optional): Power flows (pf, pt, qf, qt).
reduction (str, optional): Reduction method. Defaults to 'mean'. Must be one of 'mean', 'sum', 'none'.
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to True.
Returns
dict[str, Tensor]: Dictionary of violations:
- vm_lower: Voltage magnitude lower bound violation.
- vm_upper: Voltage magnitude upper bound violation.
- pg_lower: Real power generation lower bound violation.
- pg_upper: Real power generation upper bound violation.
- qg_lower: Reactive power generation lower bound violation.
- qg_upper: Reactive power generation upper bound violation.
- thrm_1: Thermal limit violation at the from end.
- thrm_2: Thermal limit violation at the to end.
- p_balance: Real power balance violation.
- q_balance: Reactive power balance violation.
- dva_lower: Voltage angle difference lower bound violation.
- dva_upper: Voltage angle difference upper bound violation.
def dva_bound_residual(self, dva: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the voltage angle difference bound residual.
g_{\text{lower}} = \text{angmin} - \text{dva}
g_{\text{upper}} = \text{dva} - \text{angmax}
Args
dva (Tensor): Voltage angle difference per branch. (batch_size, nbranch)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Lower bound residual. (batch_size, nbranch)
Tensor: Upper bound residual. (batch_size, nbranch)
def flows_from_voltage(self, vm: torch.Tensor, dva: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
-
Compute the power flows given the voltage magnitude and angle differences.
Args
vm (Tensor): Voltage magnitude per bus ( \mathbf{v} ). (batch_size, nbus)
dva (Tensor): Angle differences per branch ( \boldsymbol{\theta}_f - \boldsymbol{\theta}_t ). (batch_size, nbranch)
Returns
Tensor: Real power flow per branch at the from end ( \mathbf{p}_f ). (batch_size, nbranch)
Tensor: Real power flow per branch at the to end ( \mathbf{p}_t ). (batch_size, nbranch)
Tensor: Reactive power flow per branch at the from end ( \mathbf{q}_f ). (batch_size, nbranch)
Tensor: Reactive power flow per branch at the to end ( \mathbf{q}_t ). (batch_size, nbranch)
def flows_from_voltage_bus(self, vm: torch.Tensor, va: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
-
Compute the power flows given the voltage magnitude per bus and voltage angles per bus.
This function computes the angle differences, then calls ACViolation.flows_from_voltage. See the docstring of ACViolation.flows_from_voltage for more details.
Args
vm (Tensor): Voltage magnitude per bus ( \mathbf{v} ). (batch_size, nbus)
va (Tensor): Voltage angle per bus ( \boldsymbol{\theta} ). (batch_size, nbus)
Returns
Tensor: Real power flow per branch at the from end ( \mathbf{p}_f ). (batch_size, nbranch)
Tensor: Real power flow per branch at the to end ( \mathbf{p}_t ). (batch_size, nbranch)
Tensor: Reactive power flow per branch at the from end ( \mathbf{q}_f ). (batch_size, nbranch)
Tensor: Reactive power flow per branch at the to end ( \mathbf{q}_t ). (batch_size, nbranch)
def objective(self, pg: torch.Tensor) ‑> torch.Tensor
-
Compute the objective function given the active power generation per generator.
Args
pg (Tensor): Active power generation per generator. (batch_size, ngen)
Returns
Tensor: Objective function value. (batch_size)
def pg_bound_residual(self, pg: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the active power generation bound residual.
g_{\text{lower}} = \text{pmin} - \text{pg}
g_{\text{upper}} = \text{pg} - \text{pmax}
Args
pg (Tensor): Active power generation per generator. (batch_size, ngen)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Lower bound residual. (batch_size, ngen)
Tensor: Upper bound residual. (batch_size, ngen)
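The bound residuals follow the convention that g ≤ 0 means the bound is satisfied, so clamp=True simply keeps the positive part, leaving only the violations. A plain-Python sketch of the two formulas for a single sample (the helper name is hypothetical):

```python
def pg_bound_residual_sketch(pg, pmin, pmax, clamp=False):
    """Return (pmin - pg, pg - pmax) per generator; clamp=True keeps only violations."""
    lower = [lo - p for p, lo in zip(pg, pmin)]
    upper = [p - hi for p, hi in zip(pg, pmax)]
    if clamp:
        lower = [max(g, 0.0) for g in lower]
        upper = [max(g, 0.0) for g in upper]
    return lower, upper


pg, pmin, pmax = [0.5, 2.5, -0.25], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]
print(pg_bound_residual_sketch(pg, pmin, pmax))              # ([-0.5, -2.5, 0.25], [-1.5, 0.5, -2.25])
print(pg_bound_residual_sketch(pg, pmin, pmax, clamp=True))  # ([0.0, 0.0, 0.25], [0.0, 0.5, 0.0])
```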
def qg_bound_residual(self, qg: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the reactive power generation bound residual.
g_{\text{lower}} = \text{qmin} - \text{qg}
g_{\text{upper}} = \text{qg} - \text{qmax}
Args
qg (Tensor): Reactive power generation per generator. (batch_size, ngen)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Lower bound residual. (batch_size, ngen)
Tensor: Upper bound residual. (batch_size, ngen)
def thermal_residual(self,
pf: torch.Tensor,
pt: torch.Tensor,
qf: torch.Tensor,
qt: torch.Tensor,
clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the thermal limit residual.
g_{\text{thrm}_1} = \text{pf}^2 + \text{qf}^2 - \text{s1max}
g_{\text{thrm}_2} = \text{pt}^2 + \text{qt}^2 - \text{s2max}
Args
pf (Tensor): Active power flow from bus per branch. (batch_size, nbranch)
pt (Tensor): Active power flow to bus per branch. (batch_size, nbranch)
qf (Tensor): Reactive power flow from bus per branch. (batch_size, nbranch)
qt (Tensor): Reactive power flow to bus per branch. (batch_size, nbranch)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Thermal limit residual at the from end of each branch. (batch_size, nbranch)
Tensor: Thermal limit residual at the to end of each branch. (batch_size, nbranch)
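The thermal residual compares squared apparent power directly against s1max and s2max, so those limits are presumably squared ratings; the documentation does not state this, so treat it as an assumption. A one-sample sketch of the from-end residual (the to end is identical with pt, qt, s2max; the helper name is hypothetical):

```python
def thermal_residual_sketch(pf, qf, s1max, clamp=False):
    """g_thrm_1 = pf^2 + qf^2 - s1max per branch (s1max assumed to be a squared rating)."""
    g = [p ** 2 + q ** 2 - s for p, q, s in zip(pf, qf, s1max)]
    return [max(v, 0.0) for v in g] if clamp else g


# Branch 0: |S| = 5 against a rating of 4 (s1max = 16); branch 1: |S| = 1 against a rating of 2 (s1max = 4).
print(thermal_residual_sketch([3.0, 0.0], [4.0, 1.0], [16.0, 4.0]))              # [9.0, -3.0]
print(thermal_residual_sketch([3.0, 0.0], [4.0, 1.0], [16.0, 4.0], clamp=True))  # [9.0, 0.0]
```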
def vm_bound_residual(self, vm: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the voltage magnitude bound residual.
g_{\text{lower}} = \text{vmin} - \text{vm}
g_{\text{upper}} = \text{vm} - \text{vmax}
Args
vm (Tensor): Voltage magnitude per bus. (batch_size, nbus)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Lower bound residual. (batch_size, nbus)
Tensor: Upper bound residual. (batch_size, nbus)
Inherited members
class BoundRepair (xmin: torch.Tensor | None,
xmax: torch.Tensor | None,
method: str = 'relu',
sanity_check: bool = True,
memory_efficient: int = 0)
-
An activation function that clips the output to a given range.
Initializes the BoundRepair module.
If both xmin and xmax are None, per-sample bounds must be provided as input to the forward method. In this case memory_efficient is ignored (it is set to 2 regardless).
Args
xmin (Tensor): Lower bounds for clipping.
xmax (Tensor): Upper bounds for clipping.
method (str): The method to use for clipping. One of ["relu", "sigmoid", "clamp", "softplus", "tanh", "none"].
sanity_check (bool): If True, performs sanity checks on the input.
memory_efficient (int): 0: pre-compute masks and pre-index bounds; 1: pre-compute masks only; 2: do not pre-compute anything.
Ancestors
- torch.nn.modules.module.Module
Class variables
var SUPPORTED_METHODS
Static methods
def double_relu(x: torch.Tensor, xmin: torch.Tensor, xmax: torch.Tensor)
-
ReLU bound repair function for double-sided bounds.
\text{relu}(x - \underline{x}) - \text{relu}(x - \overline{x}) + \underline{x}
Args
x (Tensor): Input tensor.
xmin (Tensor): Lower bound.
xmax (Tensor): Upper bound.
Returns
Tensor: Output tensor satisfying the bounds.
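A quick numerical check of the double-sided ReLU formula: for finite bounds with xmin ≤ xmax it coincides with clamping, which is why its output always lies within the bounds. Plain Python scalars stand in for tensors here.

```python
def relu(z):
    return max(z, 0.0)


def double_relu(x, xmin, xmax):
    """relu(x - xmin) - relu(x - xmax) + xmin"""
    return relu(x - xmin) - relu(x - xmax) + xmin


# Below, inside, and above the interval [-1, 2]: identical to clamping.
for x in [-3.0, -1.0, 0.5, 2.0, 5.0]:
    assert double_relu(x, -1.0, 2.0) == min(max(x, -1.0), 2.0)
print([double_relu(x, -1.0, 2.0) for x in [-3.0, 0.5, 5.0]])  # [-1.0, 0.5, 2.0]
```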
def double_sigmoid(x: torch.Tensor, xmin: torch.Tensor, xmax: torch.Tensor)
-
Sigmoid bound repair function for double-sided bounds.
\text{sigmoid}(x) \cdot (\overline{x} - \underline{x}) + \underline{x}
Args
x (Tensor): Input tensor.
xmin (Tensor): Lower bound.
xmax (Tensor): Upper bound.
Returns
Tensor: Output tensor satisfying the bounds.
def double_softplus(x: torch.Tensor, xmin: torch.Tensor, xmax: torch.Tensor)
-
Softplus bound repair function for double-sided bounds.
\text{softplus}(x - \underline{x}) - \text{softplus}(x - \overline{x}) + \underline{x}
Args
x (Tensor): Input tensor.
xmin (Tensor): Lower bound.
xmax (Tensor): Upper bound.
Returns
Tensor: Output tensor satisfying the bounds.
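Unlike the ReLU version, the softplus variant is smooth and strictly increasing. In exact arithmetic it maps into the open interval (xmin, xmax), approaching the bounds only asymptotically, but it also moves values that were already feasible (for example, x = xmax maps well inside the interval). In floating point the output can round to a bound for extreme inputs. A scalar sketch using a numerically stable softplus:

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def double_softplus(x, xmin, xmax):
    """softplus(x - xmin) - softplus(x - xmax) + xmin"""
    return softplus(x - xmin) - softplus(x - xmax) + xmin


xs = [-50.0, -1.0, 0.0, 1.0, 50.0]
ys = [double_softplus(x, -1.0, 1.0) for x in xs]
for x, y in zip(xs, ys):
    print(f"{x:6.1f} -> {y:.4f}")
```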
def double_tanh(x: torch.Tensor, xmin: torch.Tensor, xmax: torch.Tensor)
-
Tanh bound repair function for double-sided bounds.
(\frac{1}{2} \tanh(x) + \frac{1}{2}) \cdot (\overline{x} - \underline{x}) + \underline{x}
Args
x (Tensor): Input tensor.
xmin (Tensor): Lower bound.
xmax (Tensor): Upper bound.
Returns
Tensor: Output tensor satisfying the bounds.
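The tanh and sigmoid variants are closely related: because (1/2)tanh(x) + 1/2 = sigmoid(2x), double_tanh(x) equals double_sigmoid(2x), i.e. the same map with a steeper input scale. A scalar check of that identity using the two formulas as written above:

```python
import math


def double_sigmoid(x, xmin, xmax):
    """sigmoid(x) * (xmax - xmin) + xmin"""
    return 1.0 / (1.0 + math.exp(-x)) * (xmax - xmin) + xmin


def double_tanh(x, xmin, xmax):
    """(tanh(x) / 2 + 1/2) * (xmax - xmin) + xmin"""
    return (0.5 * math.tanh(x) + 0.5) * (xmax - xmin) + xmin


for x in [-3.0, -0.5, 0.0, 0.5, 3.0]:
    assert math.isclose(double_tanh(x, 0.0, 4.0), double_sigmoid(2 * x, 0.0, 4.0),
                        rel_tol=1e-12, abs_tol=1e-12)
print(double_tanh(0.0, 0.0, 4.0))  # 2.0 (the midpoint of [0, 4])
```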
def lower_relu(x: torch.Tensor, xmin: torch.Tensor)
-
ReLU bound repair function for lower bounds.
\text{relu}(x - \underline{x}) + \underline{x}
Args
x (Tensor): Input tensor.
xmin (Tensor): Lower bound.
Returns
Tensor: Output tensor satisfying the bounds.
def lower_softplus(x: torch.Tensor, xmin: torch.Tensor)
-
Softplus bound repair function for lower bounds.
\text{softplus}(x - \underline{x}) + \underline{x}
Args
x (Tensor): Input tensor.
xmin (Tensor): Lower bound.
Returns
Tensor: Output tensor satisfying the bounds.
def upper_relu(x: torch.Tensor, xmax: torch.Tensor)
-
ReLU bound repair function for upper bounds.
-\text{relu}(\overline{x} - x) + \overline{x}
Args
x (Tensor): Input tensor.
xmax (Tensor): Upper bound.
Returns
Tensor: Output tensor satisfying the bounds.
def upper_softplus(x: torch.Tensor, xmax: torch.Tensor)
-
Softplus bound repair function for upper bounds.
-\text{softplus}(\overline{x} - x) + \overline{x}
Args
x (Tensor): Input tensor.
xmax (Tensor): Upper bound.
Returns
Tensor: Output tensor satisfying the bounds.
Methods
def clamp(self, x: torch.Tensor)
-
Bound repair function that uses torch.clamp.
\text{clamp}(x, \underline{x}, \overline{x})
Args
x (Tensor): Input tensor.
Returns
Tensor: Output tensor satisfying the bounds.
def forward(self,
x: torch.Tensor,
xmin: torch.Tensor | None = None,
xmax: torch.Tensor | None = None) ‑> Callable[..., Any]
-
Applies the bound clipping function to the input.
def load_state_dict(self, state_dict: dict, strict: bool = True)
-
Loads the state dictionary and re-initializes the pre-computed quantities.
def none(self, x: torch.Tensor)
-
No-op: returns x unchanged.
def preprocess_bounds(self, memory_efficient: int)
-
Pre-computes masks and pre-indexes bounds depending on the memory_efficient level.
Args
memory_efficient (int):
- 0: (fastest, most memory) pre-compute masks and index bounds
- 1: pre-compute masks only
- 2: (slowest, least memory) do not pre-compute anything
def relu(self, x: torch.Tensor)
-
Apply the ReLU-based bound repair functions to the input, supporting any combination of single- or double-sided bounds.
Args
x (Tensor): Input tensor.
Returns
Tensor: Output tensor satisfying the bounds.
def sigmoid(self, x: torch.Tensor)
-
Apply the sigmoid bound repair function to the input, supporting only unbounded or double-sided bounds.
Args
x (Tensor): Input tensor.
Returns
Tensor: Output tensor satisfying the bounds.
def softplus(self, x: torch.Tensor)
-
Apply the softplus bound-clipping function to the input, supporting any combination of single- or double-sided bounds.
Args
x (Tensor): Input tensor.
Returns
Tensor: Output tensor satisfying the bounds.
def tanh(self, x: torch.Tensor)
-
Apply the tanh bound-clipping function to the input, supporting only unbounded or double-sided bounds.
Args
x (Tensor): Input tensor.
Returns
Tensor: Output tensor satisfying the bounds.
class DCModel
-
OPFModel for DCOPF.
Ancestors
- OPFModel
- abc.ABC
Subclasses
Class variables
var problem : DCProblem
var violation : DCViolation
Methods
def evaluate_model(self, reduction: str | None = None, inner_reduction: str | None = None) ‑> dict[str, torch.Tensor]
-
Evaluate the model on the test data.
Args
reduction (str, optional): Reduction method for the metrics. Defaults to None. Must be one of "mean", "sum", "max", "none". If specified, each value in the returned dictionary will be a scalar. Otherwise, they are arrays of shape (n_test_samples,).
inner_reduction (str, optional): Reduction method for turning metrics calculated per component into per-sample metrics. Defaults to None. Must be one of "mean", "sum", "max", "none".
Returns
dict[str, Tensor]: Dictionary containing Tensor metrics of the model's performance:
- pg_lower: Generator lower bound violation.
- pg_upper: Generator upper bound violation.
- dva_lower: Angle difference limit lower bound violation.
- dva_upper: Angle difference limit upper bound violation.
- pf_lower: Flow limit lower bound violation.
- pf_upper: Flow limit upper bound violation.
- p_balance: Power balance violation.
- pg_mae: Mean absolute error of the real power generation.
- va_mae: Mean absolute error of the voltage angle (skipped if not bus-wise and va is not in the predictions).
- pf_mae: Mean absolute error of the real power flow.
- obj_mape: Mean absolute percent error of the objective value.
def predict(self, pd: torch.Tensor) ‑> dict[str, torch.Tensor]
-
Predict the DCOPF primal solution for a given set of loads.
Args
pd (Tensor): Active power demand per load.
Returns
dict[str, Tensor]: Dictionary containing the predicted primal solution:
- pg: Active power generation per generator or per bus.
- va: Voltage angle per bus.
Inherited members
class DCProblem (data_directory: str, dataset_name: str = 'DCOPF', **parse_kwargs)
-
OPFProblem for DCOPF.
Ancestors
- OPFProblem
- abc.ABC
Instance variables
prop default_combos : dict[str, list[str]]
-
Default combos for DCOPF:
- input: pd
- target: pg, va
prop default_order : list[str]
-
Default order for DCOPF: input, target
prop feasibility_check : dict[str, str]
-
Default feasibility check for DCOPF:
- termination_status: "OPTIMAL"
- primal_status: "FEASIBLE_POINT"
- dual_status: "FEASIBLE_POINT"
prop violation : DCViolation
-
OPFViolation object for DCOPF constraint calculations.
Inherited members
class DCViolation (data: dict[str, torch.Tensor])
-
OPFViolation for DCPPowerModel/DCOPF.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- OPFViolation
- torch.nn.modules.module.Module
- abc.ABC
Methods
def angle_difference(self, va: torch.Tensor) ‑> torch.Tensor
-
Compute the angle differences per branch given the voltage angles per bus.
The branch indices are assumed to be constant for the batch, matching the reference case.
\text{dva} = \boldsymbol{\theta}_{f} - \boldsymbol{\theta}_{t}
Args
va (Tensor): Voltage angles per bus ( \boldsymbol{\theta} ). (batch_size, nbus)
Returns
Tensor: Angle differences per branch. (batch_size, nbranch)
def balance_residual(self,
pd: torch.Tensor,
pg: torch.Tensor,
pf: torch.Tensor,
embed_method: str = 'pad',
clamp: bool = True) ‑> torch.Tensor
-
Compute power balance residual.
\text{g_balance} = \text{pg_bus} - \text{pd_bus} - \text{gs_bus} - \text{pf_bus} - \text{pt_bus}
Args
pd (Tensor): Power demand per bus. (batch_size, nbus)
pg (Tensor): Power generation per generator. (batch_size, ngen)
pf (Tensor): Power flow per branch. (batch_size, nbranch)
embed_method (str, optional): Embedding method to convert component-wise values to bus-wise; one of "pad", "dense_matrix", or "matrix". Defaults to "pad".
clamp (bool, optional): Clamp to extract only violations. Defaults to True.
Returns
Tensor: Power balance residual. (batch_size, nbus)
def calc_violations(self,
pd: torch.Tensor,
pg: torch.Tensor,
va: torch.Tensor,
pf: torch.Tensor | None = None,
reduction: str = 'mean',
clamp: bool = True,
embed_method: str = 'pad') ‑> dict[str, torch.Tensor]
-
Compute all DCOPF violations.
Args
pd (Tensor): Power demand per bus. (batch_size, nbus)
pg (Tensor): Power generation per generator. (batch_size, ngen)
va (Tensor): Voltage angles per bus. (batch_size, nbus)
pf (Tensor, optional): Power flow per branch. Defaults to None.
reduction (str, optional): Reduction method. Defaults to "mean".
clamp (bool, optional): Clamp to extract only violations. Defaults to True.
embed_method (str, optional): Method to convert component-wise values to bus-wise; one of "pad", "dense_matrix", or "matrix". Defaults to "pad".
Returns
dict[str, Tensor]: Dictionary of all violations:
- "pg_lower": Lower bound violation of power generation. (batch_size, ngen)
- "pg_upper": Upper bound violation of power generation. (batch_size, ngen)
- "dva_lower": Lower bound violation of voltage angle difference. (batch_size, nbranch)
- "dva_upper": Upper bound violation of voltage angle difference. (batch_size, nbranch)
- "pf_lower": Lower bound violation of power flow. (batch_size, nbranch)
- "pf_upper": Upper bound violation of power flow. (batch_size, nbranch)
- "p_balance": Power balance violation. (batch_size, nbus)
- "ohm": Ohm's law violation. (batch_size, nbranch)
def dva_bound_residual(self, dva: torch.Tensor, clamp: bool = False) ‑> torch.Tensor
-
Calculate the voltage angle difference bound residual.
g_{\text{lower}} = \text{angmin} - \text{dva}
g_{\text{upper}} = \text{dva} - \text{angmax}
Args
dva (Tensor): Voltage angle differences per branch. (batch_size, nbranch)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Lower bound residual. (batch_size, nbranch)
Tensor: Upper bound residual. (batch_size, nbranch)
def objective(self, pg: torch.Tensor) ‑> torch.Tensor
-
Compute DCOPF objective function.
Cost is assumed to be constant for the batch, matching the reference case.
\text{objective} = \sum_{i=1}^{n} \left( \text{cost}_{2,i} + \text{cost}_{1,i} \cdot \text{pg}_i \right)
Args
pg (Tensor): Power generation per generator. (batch_size, ngen)
Returns
Tensor: Objective function value. (batch_size, 1)
def ohm_residual(self, pf: torch.Tensor, dva: torch.Tensor, clamp: bool = False) ‑> torch.Tensor
-
Compute Ohm's law violation.
\text{g_ohm} = - b \cdot \text{dva} - \text{pf}
Args
pf (Tensor): Power flow per branch. (batch_size, nbranch)
dva (Tensor): Voltage angle differences per branch. (batch_size, nbranch)
clamp (bool, optional): Clamp to extract only violations. Defaults to False.
Returns
Tensor: Ohm's law violation. (batch_size, nbranch)
def pf_bound_residual(self, pf: torch.Tensor, clamp: bool = False) ‑> torch.Tensor
-
Calculate the power flow bound residual.
g_{\text{lower}} = -\text{rate_a} - \text{pf}
g_{\text{upper}} = \text{pf} - \text{rate_a}
Args
pf (Tensor): Power flow per branch. (batch_size, nbranch)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Lower bound residual. (batch_size, nbranch)
Tensor: Upper bound residual. (batch_size, nbranch)
def pf_from_va(self, va: torch.Tensor) ‑> torch.Tensor
-
Compute power flow given voltage angles.
\mathbf{p}_f = -\text{b} \cdot (\boldsymbol{\theta}_{f} - \boldsymbol{\theta}_{t})
Args
va (Tensor): Voltage angles per bus ( \boldsymbol{\theta} ). (batch_size, nbus)
Returns
Tensor: Power flow per branch ( \mathbf{p}_f ). (batch_size, nbranch)
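In the DC model the branch flow follows directly from the angle difference, so computing pf from predicted angles with the pf_from_va formula makes the Ohm's-law residual vanish by construction. The sketch below assumes per-branch susceptances b and 0-based endpoint lists bus_fr and bus_to (hypothetical names). The b values are negative, matching the common b = -1/x convention for a lossless line; that sign convention is itself an assumption about the data.

```python
def pf_from_va_sketch(va, b, bus_fr, bus_to):
    """pf = -b * (va[f] - va[t]) per branch, for a single sample."""
    return [-bl * (va[f] - va[t]) for bl, f, t in zip(b, bus_fr, bus_to)]


def ohm_residual_sketch(pf, dva, b):
    """g_ohm = -b * dva - pf per branch."""
    return [-bl * d - p for p, d, bl in zip(pf, dva, b)]


va = [0.0, -0.25, -0.5]
b, bus_fr, bus_to = [-4.0, -2.0], [0, 1], [1, 2]
pf = pf_from_va_sketch(va, b, bus_fr, bus_to)
dva = [va[f] - va[t] for f, t in zip(bus_fr, bus_to)]
print(pf)                               # [1.0, 0.5]
print(ohm_residual_sketch(pf, dva, b))  # [0.0, 0.0]
```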
def pg_bound_residual(self, pg: torch.Tensor, clamp: bool = False) ‑> torch.Tensor
-
Calculate the power generation bound residual.
g_{\text{lower}} = \text{pmin} - \text{pg}
g_{\text{upper}} = \text{pg} - \text{pmax}
Args
pg (Tensor): Active power generation per generator. (batch_size, ngen)
clamp (bool, optional): Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor: Lower bound residual. (batch_size, ngen)
Tensor: Upper bound residual. (batch_size, ngen)
Inherited members
class EDModel
-
OPFModel for EconomicDispatch.
Ancestors
- OPFModel
- abc.ABC
Subclasses
Class variables
var problem : EDProblem
var violation : EDViolation
Methods
def evaluate_model(self, reduction: str | None = None, inner_reduction: str | None = None) ‑> dict[str, torch.Tensor]
-
Evaluate the model on the test data.
Args
reduction (str, optional): Reduction method for the metrics. Defaults to None. Must be one of "mean", "sum", "max", "none". If specified, each value in the returned dictionary will be a scalar. Otherwise, they are arrays of shape (n_test_samples,).
inner_reduction (str, optional): Reduction method for turning metrics calculated per component into per-sample metrics. Defaults to None. Must be one of "mean", "sum", "max", "none".
Returns
dict[str, Tensor]: Dictionary containing Tensor metrics of the model's performance:
- pg_lower: Generator lower bound violation.
- pg_upper: Generator upper bound violation.
- pf_lower: Branch power flow lower bound violation.
- pf_upper: Branch power flow upper bound violation.
- p_balance: Power balance violation.
- pg_mae: Mean absolute error of the real power generation.
- obj_mape: Mean absolute percent error of the objective value.
Inherited members
class EDProblem (data_directory: str, ptdf_path: str, dataset_name: str = 'ED', **parse_kwargs)
-
OPFProblem for EconomicDispatch.
Ancestors
- OPFProblem
- abc.ABC
Instance variables
prop default_combos : dict[str, list[str]]
-
Default combos for EconomicDispatch:
- input: pd
- target: pg, va
prop default_order : list[str]
-
Default order for EconomicDispatch: input, target
prop feasibility_check : dict[str, str]
-
Default feasibility check for EconomicDispatch.
prop violation : EDViolation
-
OPFViolation object for EconomicDispatch constraint calculations.
Inherited members
class EDViolation (data: dict, ptdf: torch.Tensor)
-
OPFViolation for EconomicDispatch.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Ancestors
- OPFViolation
- torch.nn.modules.module.Module
- abc.ABC
Methods
def balance_residual(self,
pd: torch.Tensor,
pg: torch.Tensor,
dpb_surplus: torch.Tensor | None = None,
dpb_shortage: torch.Tensor | None = None,
clamp: bool = False) ‑> torch.Tensor
-
Compute power balance residual.
def calc_violations(self,
pd: torch.Tensor,
pg: torch.Tensor,
reduction: str = 'mean',
clamp: bool = True) ‑> dict[str, torch.Tensor]
-
Compute all EconomicDispatch violations.
def objective(self, pd: torch.Tensor, pg: torch.Tensor) ‑> torch.Tensor
-
Compute ED objective function.
def pf_bound_residual(self, pf: torch.Tensor, df: torch.Tensor | None = None, clamp: bool = False) ‑> torch.Tensor
-
Compute power flow bound residual.
def pf_from_pdpg(self, pd: torch.Tensor, pg: torch.Tensor, dense_incidence: bool = False) ‑> torch.Tensor
-
Compute power flow from power demand and power generation.
def pg_bound_residual(self, pg: torch.Tensor, clamp: bool = False) ‑> torch.Tensor
-
Compute power generation bound residual.
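EDViolation is constructed with a ptdf tensor, and pf_from_pdpg computes branch flows from demand and generation. Under the standard DC approximation, branch flows are the PTDF matrix applied to net nodal injections. The sketch below shows that computation in plain PyTorch; the function pf_from_ptdf and its load_bus/gen_bus arguments are hypothetical, and this is not necessarily how ml4opf embeds components or handles the dense_incidence flag.

```python
import torch


def pf_from_ptdf(pd: torch.Tensor, pg: torch.Tensor, ptdf: torch.Tensor,
                 load_bus: torch.Tensor, gen_bus: torch.Tensor) -> torch.Tensor:
    """Hypothetical DC flow sketch.

    pd: (batch, n_load), pg: (batch, n_gen), ptdf: (n_branch, n_bus)
    load_bus / gen_bus: 0-based bus index of each load / generator (int64)
    """
    n_bus = ptdf.shape[1]
    inj = torch.zeros(pd.shape[0], n_bus, dtype=pd.dtype)
    inj.index_add_(1, gen_bus, pg)    # add generation at each generator's bus
    inj.index_add_(1, load_bus, -pd)  # subtract demand at each load's bus
    return inj @ ptdf.T               # (batch, n_branch)
```

In a two-bus example with the slack at bus 1, a PTDF row of [1, 0] sends the full injection at bus 0 across the single branch.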
Inherited members
class HyperSimplexRepair (xmin: torch.Tensor | None = None,
xmax: torch.Tensor | None = None,
X: torch.Tensor | None = None)
-
Repair layer for the hyper-simplex constraint ∑x=X, x̲≤x≤x̅.
Ancestors
- torch.nn.modules.module.Module
Methods
def forward(self,
x: torch.Tensor,
xmin: torch.Tensor | None = None,
xmax: torch.Tensor | None = None,
X: torch.Tensor | None = None) ‑> Callable[..., Any]
-
Project onto ∑x=X, x̲≤x≤x̅.
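One simple way to map a point onto the set ∑x=X, x̲≤x≤x̅ is a proportional repair: clamp to the box, then move every component toward its upper (or lower) bound by the same fraction of its remaining room until the sum matches X. The sketch below assumes x is (batch, n), the bounds are (n,), and X is (batch,). It is a hypothetical illustration (proportional_repair is not an ml4opf function); it is not a Euclidean projection and not necessarily the method HyperSimplexRepair uses.

```python
import torch


def proportional_repair(x: torch.Tensor, xmin: torch.Tensor, xmax: torch.Tensor,
                        X: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical repair onto {sum(x) = X, xmin <= x <= xmax}."""
    x = torch.minimum(torch.maximum(x, xmin), xmax)  # enforce the box first
    X = X.reshape(-1, 1)
    gap = X - x.sum(dim=1, keepdim=True)             # > 0: need more, < 0: need less
    room_up = (xmax - x).sum(dim=1, keepdim=True)
    room_down = (x - xmin).sum(dim=1, keepdim=True)
    # Common fraction of the remaining room that every component moves in a row.
    a_up = (gap / room_up.clamp_min(eps)).clamp(0.0, 1.0)
    a_down = (-gap / room_down.clamp_min(eps)).clamp(0.0, 1.0)
    return torch.where(gap >= 0, x + a_up * (xmax - x), x - a_down * (x - xmin))
```

If X lies outside [∑x̲, ∑x̅], the fraction saturates at 1 and the sum constraint cannot be met; the result then sits at the corresponding bound vector.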
class LDFLoss (v: OPFViolation,
step_size: float,
kickin: int,
update_freq: int,
divide_by_counter: bool = True,
exclude_keys: str | list[str] | None = None)
-
LDFLoss implements the Lagrangian Dual Framework.
exclude_keys is either None to use all violations, "all" to skip all violations, or a list of keys to skip specific violations.
Initialize LDFLoss module.
Ancestors
- torch.nn.modules.module.Module
Methods
def end_epoch(self)
-
Call this method at the end of each epoch.
def forward(self,
base_loss: torch.Tensor,
exclude_keys: str | list[str] | None = None,
**calc_violation_inputs: torch.Tensor) ‑> torch.Tensor
-
Compute the LDF Loss for a batch of samples.
def init_mults(self, shapes=None)
-
Initialize λ and trackers to zeros.
def reset_trackers(self)
-
Reset the violation trackers to zeros.
def start_epoch(self, epoch) ‑> str | None
-
Call this method at the start of each epoch.
def update(self)
-
Update the Lagrangian dual multipliers (λ).
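In the Lagrangian Dual Framework, each constraint violation is penalized by a multiplier λ, and λ is periodically increased by dual ascent in proportion to the observed violation. The sketch below is a simplified, hypothetical stand-in (MinimalLDF is not part of ml4opf): it keeps one scalar λ per violation key, always averages the tracked violations over the batches seen since the last update, and ignores kickin, update_freq, and exclude_keys.

```python
import torch


class MinimalLDF:
    """Hypothetical, simplified Lagrangian Dual Framework loss (not ml4opf's LDFLoss)."""

    def __init__(self, keys: list[str], step_size: float):
        self.step_size = step_size
        self.lam = {k: 0.0 for k in keys}  # one scalar multiplier per violation key
        self.acc = {k: 0.0 for k in keys}  # running sum of mean violations since last update
        self.count = 0

    def loss(self, base_loss: torch.Tensor, violations: dict[str, torch.Tensor]) -> torch.Tensor:
        # violations: non-negative (batch, n) tensors, e.g. clamped residuals
        penalty = sum(self.lam[k] * v.mean() for k, v in violations.items())
        for k, v in violations.items():
            self.acc[k] += v.detach().mean().item()
        self.count += 1
        return base_loss + penalty

    def update(self) -> None:
        # Dual ascent: lambda_k <- lambda_k + step_size * (average violation of k)
        for k in self.lam:
            self.lam[k] += self.step_size * self.acc[k] / max(self.count, 1)
            self.acc[k] = 0.0
        self.count = 0
```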
class OPFModel
-
An abstract base class for OPF models.
Ancestors
- abc.ABC
Subclasses
Static methods
def load_from_checkpoint(path_to_folder: str,
problem: OPFProblem)
-
Load the model's checkpoint from a folder.
Args
path_to_folder
:str
- Path to the folder to load the checkpoint from.
Methods
def evaluate_model(self, reduction: str | None = None, inner_reduction: str | None = None) ‑> dict[str, torch.Tensor]
-
Evaluate the model on the test data.
Args
reduction
:str
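, optional- Reduction method for the metrics. Defaults to None. Must be one of "mean", "sum", "none". If specified, each value in the returned dictionary will be a scalar. Otherwise, they are arrays of shape (n_test_samples,).
inner_reduction
:str
, optional- Reduction method for turning metrics calculated per component into per-sample metrics. Defaults to None.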
Returns
dict[str, Tensor]
- Dictionary containing Tensor metrics of the model's performance.
def predict(self, *inputs: torch.Tensor) ‑> dict[str, torch.Tensor]
-
Predict the solution for a given set of inputs.
Args
*inputs
:Tensor
- Input tensors to the model.
Returns
dict[str, Tensor]
- Dictionary containing the solution.
def save_checkpoint(self, path_to_folder: str)
-
Save the model's checkpoint to a folder.
Args
path_to_folder
:str
- Path to the folder to save the checkpoint to.
class OPFProblem (data_directory: str, dataset_name: str, **parse_kwargs)
-
OPF Problem
This class parses the JSON/HDF5 files on initialization, providing a standard interface for accessing OPF data.
OPFProblem also includes methods for creating input/target tensors from the HDF5 data by concatenating keys, though more complex datasets (e.g., for graph neural networks) can be created by accessing train_data and json_data directly.
By default, initializing OPFProblem will parse the HDF5/JSON files, remove infeasible samples, and set aside 5000 samples for testing. The test data can be accessed via test_data; train_data will only contain the training data. Models should split the training data into training/validation sets themselves downstream.
Initialization Arguments:
- data_directory (str): Path to the folder containing the problem files
- dataset_name (str): Name of the problem to use
- primal (bool): Whether to parse the primal data (default: True)
- dual (bool): Whether to parse the dual data (default: True)
- train_set (bool): Whether to parse the training set (default: True)
- test_set (bool): Whether to parse the test set (default: True)
- convert_to_float32 (bool): Whether to convert the data to float32 (default: True)
- sanity_check (bool): Whether to perform a sanity check on the parsed data (default: True)
Attributes:
- path (Path): Path to the problem file folder
- name (str): Name of the problem to use
- train_data (dict): Dictionary of parsed HDF5 data. If make_test_set is True, this is only the training set.
- test_data (dict): Dictionary of parsed HDF5 data for the test set. If make_test_set is False, this is None.
- json_data (dict): Dictionary of parsed JSON data.
- violation (OPFViolation): OPFViolation object for computing constraint violations for this problem.
Methods:
- parse: Parse the JSON and HDF5 files for the problem.
- make_dataset: Create input/target tensors by combining keys from the h5 data. Returns the TensorDataset and slices for extracting the original components.
- slice_batch: Extract the original components from a batch of data given the slices.
- slice_tensor: Extract the original components from a tensor given the slices.
Ancestors
- abc.ABC
Subclasses
Static methods
def slice_batch(batch: tuple[torch.Tensor, ...], slices: list[dict[str, slice]])
-
Slice the batch tensors into the original tensors
Args
batch
:tuple[Tensor, …]
- Batch of tensors from the TensorDataset
slices
:list[dict[str, slice]]
- List of dictionaries of slices
Returns
tuple[dict[str, Tensor], …]
- Sliced tensors
def slice_tensor(tensor: torch.Tensor, slices: dict[str, slice])
-
Slice the tensor into the original tensors
Args
tensor
:Tensor
- Tensor to slice
slices
:dict[str, slice]
- Dictionary of slices
Returns
dict[str, Tensor]
- Sliced tensors
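make_dataset concatenates several train_data keys into one tensor and returns slices so the components can be recovered with slice_tensor or slice_batch. The sketch below reproduces that bookkeeping for a single dict of tensors, assuming concatenation along the feature dimension (dim 1); concat_with_slices and slice_back are hypothetical helpers, not ml4opf APIs.

```python
import torch


def concat_with_slices(parts: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, slice]]:
    """Concatenate (batch, k_i) tensors along dim 1, remembering each one's column range."""
    slices: dict[str, slice] = {}
    start = 0
    for key, t in parts.items():
        slices[key] = slice(start, start + t.shape[1])
        start += t.shape[1]
    return torch.cat(list(parts.values()), dim=1), slices


def slice_back(tensor: torch.Tensor, slices: dict[str, slice]) -> dict[str, torch.Tensor]:
    """Recover the original components from the concatenated tensor."""
    return {key: tensor[:, s] for key, s in slices.items()}


# Usage with two hypothetical input components for a batch of 4 samples.
parts = {"pd": torch.arange(12.0).reshape(4, 3), "qd": torch.arange(8.0).reshape(4, 2)}
x, slices = concat_with_slices(parts)
recovered = slice_back(x, slices)
```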
Instance variables
prop default_combos : dict[str, list[str]]
-
A dictionary where keys represent elements of the tuple from the TensorDataset and values are keys of the train_data dictionary which are concatenated. Used by make_dataset.
prop default_order : list[str]
-
The order of the keys in the default_combos dictionary.
prop feasibility_check : dict[str, str]
-
Dictionary of keys and values to check feasibility of the problem.
Each key is checked to have the corresponding value. If any of them does not match, the sample is removed from the dataset in PGLearnParser. See ACOPFProblem.feasibility_check for an example.
Methods
def make_dataset(self,
combos: dict[str, list[str]] | None = None,
order: list[str] | None = None,
data: dict[str, torch.Tensor] | None = None,
test_set: bool = False,
sanity_check: bool = True) ‑> tuple[dict[str, torch.Tensor], list[dict[str, slice]]]
-
Make a TensorDataset from self.train_data given the keys in combos and the order of the keys in order.
def parse(self,
primal: bool = True,
dual: bool = True,
train_set: bool = True,
test_set: bool = True,
convert_to_float32: bool = True,
sanity_check: bool = True)
-
Parse the JSON and HDF5 files for the problem.
class OPFViolation (data: dict[str, torch.Tensor])
-
The OPFViolation class is where all the problem expressions (objective, constraints, etc.) are defined. The classes provide a convenient interface for the case where the only quantity that varies per sample is the load demand pd. If other quantities vary, use the functional interface at ml4opf.functional instead.
When clamp is true, the inequality residuals g(x) are clamped to be non-negative and the equality residuals h(x) are replaced by their absolute values, so only violations are reported. Otherwise the raw residual values are returned.
OPFViolation is a torch.nn.Module; all tensors used in computation are registered as non-persistent buffers. To move the module to a different device, use .to(device) as you would with any other nn.Module. Make sure that any data you pass to OPFViolation is on the same device as the module.
Ancestors
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
Class variables
var SUPPORTED_REDUCTIONS
Static methods
def clamped_bound_residual(x: torch.Tensor, xmin: torch.Tensor, xmax: torch.Tensor, clamp: bool)
def reduce_violation(violation: torch.Tensor, reduction: str, dim: int = 1)
-
Apply mean/sum/max to a single tensor.
def reduce_violations(violations: dict[str, torch.Tensor], reduction: str, dim: int = 1)
-
Apply mean/sum/max to every value in a dictionary.
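The clamp semantics and the reduction helpers above can be sketched as follows. bound_residual follows the g(x) ≤ 0 convention used by the bound residuals in this module (lower = xmin - x, upper = x - xmax), and reduce_violation_sketch applies one reduction along the component dimension. Both names are hypothetical, and the "none" branch is an assumption based on the reduction options listed elsewhere in these docs.

```python
import torch


def bound_residual(x: torch.Tensor, xmin: torch.Tensor, xmax: torch.Tensor,
                   clamp: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """Residuals of xmin <= x <= xmax written as g(x) <= 0."""
    lower, upper = xmin - x, x - xmax
    if clamp:  # keep only the violated (positive) part
        lower, upper = lower.clamp(min=0), upper.clamp(min=0)
    return lower, upper


def reduce_violation_sketch(v: torch.Tensor, reduction: str, dim: int = 1) -> torch.Tensor:
    """Reduce a (batch, components) violation tensor along `dim`."""
    if reduction == "mean":
        return v.mean(dim=dim)
    if reduction == "sum":
        return v.sum(dim=dim)
    if reduction == "max":
        return v.max(dim=dim).values
    if reduction == "none":
        return v
    raise ValueError(f"unsupported reduction: {reduction}")
```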
Instance variables
prop adjacency_matrix : torch.Tensor
-
Sparse adjacency matrix.
Each row corresponds to a bus and each column corresponds to a bus. The value is 1 if there is a branch between the buses, and 0 otherwise.
Args
fbus
:Tensor
- From bus indices. (nbranch,)
tbus
:Tensor
- To bus indices. (nbranch,)
n_bus
:int
- Number of buses.
n_branch
:int
- Number of branches.
Returns
Tensor
- Sparse adjacency matrix. (nbus, nbus)
prop adjacency_matrix_dense : torch.Tensor
prop branch_from_incidence : torch.Tensor
-
Sparse branch from incidence matrix.
Each row corresponds to a bus and each column corresponds to a branch. The value is 1 if the branch is from the bus, and 0 otherwise.
Returns
Tensor
- Sparse branch from incidence matrix. (nbus, nbranch)
prop branch_from_incidence_dense : torch.Tensor
prop branch_incidence : torch.Tensor
-
Sparse branch incidence matrix.
Each row corresponds to a bus and each column corresponds to a branch. The value is 1 if the branch is from the bus, -1 if the branch is to the bus, and 0 otherwise.
Returns
Tensor
- Sparse branch incidence matrix. (nbus, nbranch)
prop branch_incidence_dense : torch.Tensor
prop branch_to_incidence : torch.Tensor
-
Sparse branch to incidence matrix.
Each row corresponds to a bus and each column corresponds to a branch. The value is 1 if the branch is to the bus, and 0 otherwise.
Returns
Tensor
- Sparse branch to incidence matrix. (nbus, nbranch)
prop branch_to_incidence_dense : torch.Tensor
prop generator_incidence : torch.Tensor
-
Sparse generator incidence matrix.
Each row corresponds to a bus and each column corresponds to a generator. The value is 1 if the generator is at the bus, and 0 otherwise.
Returns
Tensor
- Sparse generator incidence matrix. (nbus, ngen)
prop generator_incidence_dense : torch.Tensor
prop load_incidence : torch.Tensor
-
Sparse load incidence matrix.
Each row corresponds to a bus and each column corresponds to a load. The value is 1 if the load is at the bus, and 0 otherwise.
Returns
Tensor
- Sparse load incidence matrix. (nbus, nload)
prop load_incidence_dense : torch.Tensor
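The incidence matrices above can be built directly with torch.sparse_coo_tensor from the component-to-bus index arrays. This sketch assumes 0-based bus indices and uses the hypothetical names generator_incidence_sketch and branch_incidence_sketch; ml4opf's own functional version (ml4opf.functional.generator_incidence) takes different arguments.

```python
import torch


def generator_incidence_sketch(gen_bus: torch.Tensor, n_bus: int) -> torch.Tensor:
    """Sparse (n_bus, n_gen) matrix with a 1 at (bus of generator g, g)."""
    n_gen = gen_bus.numel()
    indices = torch.stack([gen_bus, torch.arange(n_gen)])
    values = torch.ones(n_gen)
    return torch.sparse_coo_tensor(indices, values, (n_bus, n_gen)).coalesce()


def branch_incidence_sketch(fbus: torch.Tensor, tbus: torch.Tensor, n_bus: int) -> torch.Tensor:
    """Sparse (n_bus, n_branch) matrix: +1 at the from bus, -1 at the to bus."""
    n_branch = fbus.numel()
    cols = torch.arange(n_branch)
    indices = torch.cat([torch.stack([fbus, cols]), torch.stack([tbus, cols])], dim=1)
    values = torch.cat([torch.ones(n_branch), -torch.ones(n_branch)])
    return torch.sparse_coo_tensor(indices, values, (n_bus, n_branch)).coalesce()
```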
Methods
def branch_from_to_bus(self, pf_or_qf: torch.Tensor, method: str = 'pad') ‑> torch.Tensor
-
Embed batched branch-wise values to batched bus-wise.
The default method "pad" sums over any flows on branches from the same bus. The matrix methods "dense_matrix" and "matrix" use the incidence matrix.
Args
pf_or_qf
:Tensor
- Branch-wise values. (batch, nbranch)
method
:str
- Method to use. Supported: ['pad', 'dense_matrix', 'matrix']
Returns
Tensor
- Bus-wise values. (batch, nbus)
def branch_to_to_bus(self, pt_or_qt: torch.Tensor, method: str = 'pad') ‑> torch.Tensor
-
Embed batched branch-wise values to batched bus-wise.
The default method "pad" sums over any flows on branches to the same bus. The matrix methods "dense_matrix" and "matrix" use the incidence matrix.
Args
pt_or_qt
:Tensor
- Branch-wise values. (batch, nbranch)
method
:str
- Method to use. Supported: ['pad', 'dense_matrix', 'matrix']
Returns
Tensor
- Bus-wise values. (batch, nbus)
def calc_violations(self, *args, reduction: str = 'mean', clamp: bool = True) ‑> dict[str, torch.Tensor]
-
Calculate the violations of the constraints. Returns a dictionary of tensors.
def forward(self, *args, **kwargs) ‑> Callable[..., Any]
-
Pass-through for OPFViolation.calc_violations().
def gen_to_bus(self, pg_or_qg: torch.Tensor, method: str = 'pad') ‑> torch.Tensor
-
Embed generator-wise values to bus-wise.
The default method "pad" sums over any generators at the same bus. The matrix methods "dense_matrix" and "matrix" use the incidence matrix.
Args
pg_or_qg
:Tensor
- Generator-wise values. (batch, ngen)
method
:str
- Method to use. Supported: ['pad', 'dense_matrix', 'matrix']
Returns
Tensor
- Bus-wise values. (batch, nbus)
def load_to_bus(self, pd_or_qd: torch.Tensor, method: str = 'pad') ‑> torch.Tensor
-
Embed load-wise values to bus-wise.
The default method "pad" sums over any loads at the same bus. The matrix methods "dense_matrix" and "matrix" use the incidence matrix.
Args
pd_or_qd
:Tensor
- Load-wise values. (batch, nload)
method
:str
- Method to use. Supported: ['pad', 'dense_matrix', 'matrix']
Returns
Tensor
- Bus-wise values. (batch, nbus)
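The embedding methods above all sum component values that share a bus. The padding mechanism behind "pad" is not described here, so this sketch uses index_add_ as a scatter-style stand-in and checks it against the dense incidence-matrix product that "dense_matrix" describes. gen_to_bus_scatter and gen_to_bus_matrix are hypothetical names, and 0-based bus indices are assumed.

```python
import torch


def gen_to_bus_scatter(pg: torch.Tensor, gen_bus: torch.Tensor, n_bus: int) -> torch.Tensor:
    """Sum (batch, n_gen) values into (batch, n_bus) by each generator's bus index."""
    out = torch.zeros(pg.shape[0], n_bus, dtype=pg.dtype)
    return out.index_add_(1, gen_bus, pg)


def gen_to_bus_matrix(pg: torch.Tensor, gen_bus: torch.Tensor, n_bus: int) -> torch.Tensor:
    """Same result via a dense (n_bus, n_gen) generator incidence matrix."""
    n_gen = gen_bus.numel()
    A = torch.zeros(n_bus, n_gen, dtype=pg.dtype)
    A[gen_bus, torch.arange(n_gen)] = 1.0
    return pg @ A.T
```

The scatter form avoids materializing the (nbus, ngen) matrix, which matters for large grids; the matrix form is easier to reason about and to differentiate through as a linear map.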
def objective(self, *args) ‑> torch.Tensor
-
Compute the objective value for a batch of samples.
def violation_shapes(self) ‑> dict[str, int]
-
Return the shapes of the violations returned by OPFViolation.calc_violations().
class ObjectiveLoss (v: OPFViolation,
reduction: str | None = 'mean')
-
ObjectiveLoss computes the original objective of the OPF.
It takes as input the same arguments as the corresponding formulation's objective method, and returns the objective value.
Initialize ObjectiveLoss module.
Args
v
:OPFViolation
- OPFViolation module.
reduction
:Optional[str]
- Reduction operation. Default: "mean".
Ancestors
- torch.nn.modules.module.Module
Class variables
var SUPPORTED_REDUCTIONS
Methods
def forward(self, *objective_args, **objective_kwargs) ‑> torch.Tensor
-
Compute the objective value for a batch of samples.
class PGLearnParser (data_path: str | pathlib.Path)
-
Parser for PGLearn dataset.
Initialize the parser by validating and setting the path.
Class variables
var padval
Static methods
def convert_to_float32(dat: dict[str, torch.Tensor | numpy.ndarray | numpy.str_])
-
Convert all float64 data to float32 in-place.
def make_tree(dat: dict[str, torch.Tensor | numpy.ndarray | numpy.str_],
delimiter: str = '/')
-
Convert a flat dictionary to a tree. Note that the keys of dat must have a tree structure where data is only at the leaves. By default, keys are assumed to be delimited by "/", e.g. "solution/primal/pg".
Args
dat
:dict
- Flat dictionary of data.
delimiter
:str
, optional- Delimiter to use for splitting keys. Defaults to "/".
Returns
dict
- Tree dictionary of data from dat.
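A minimal re-implementation of the flat-to-tree conversion described for make_tree, for illustration only (flat_to_tree is a hypothetical name). Each key is split on the delimiter and intermediate dictionaries are created on demand. A key that is both a leaf and a prefix of another key would either raise an error or silently overwrite a subtree, depending on key order, which is why data must live only at the leaves.

```python
def flat_to_tree(flat: dict, delimiter: str = "/") -> dict:
    """Turn {"a/b/c": v} into {"a": {"b": {"c": v}}}."""
    tree: dict = {}
    for key, value in flat.items():
        *parents, leaf = key.split(delimiter)
        node = tree
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels as needed
        node[leaf] = value
    return tree
```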
def pad_to_dense(array, padval, dtype=builtins.int)
Methods
def open_json(self)
-
Open the JSON file, supporting gzip and bz2 compression based on the file suffix.
def parse_h5(self,
dataset_name: str,
split: str = 'train',
primal: bool = True,
dual: bool = False,
convert_to_float32: bool = True) ‑> dict[str, torch.Tensor | numpy.ndarray | numpy.str_] | tuple[dict[str, torch.Tensor | numpy.ndarray | numpy.str_], dict[str, torch.Tensor | numpy.ndarray | numpy.str_]]
-
Parse the HDF5 file.
Args
dataset_name
:str
- The name of the dataset. Typically the formulation ("ACOPF", "DCOPF", etc.).
split
:str
, optional- The split to return. Defaults to "train".
primal
:bool
, optional- If True, parse the primal file. Defaults to True.
dual
:bool
, optional- If True, parse the dual file. Defaults to False.
convert_to_float32
:bool
, optional- If True, convert all float64 data to torch.float32. Defaults to True.
Returns
dict
- Flattened dictionary of HDF5 data with PyTorch tensors for numerical data and NumPy arrays for string/object data.
If make_test_set is True, then this function will return a tuple of two dictionaries. The first dictionary is the training set and the second dictionary is the test set. The test set is a random 10% sample of the training set.
This parser will return a single-level dictionary where the keys are in the form of solution/primal/pg, where solution is the group, primal is the subgroup, and pg is the dataset from the HDF5 file. Numerical values are PyTorch tensors. This parser uses h5py.File.visititems to iterate over the HDF5 file quickly.
def parse_json(self, model_type: str | Sequence[str] = None)
-
Parse the JSON file from PGLearn.
Args
model_type
:Union[str, Sequence[str]]
- The reference solutions to save. Default: [] (no reference solutions saved.)
Returns
dict
- Dictionary containing the parsed data.
In the JSON file, the data is stored by each individual component. So to get generator 1's upper bound on active generation, you'd look at: raw_json['data']['gen']['1']['pmax'] and get a float.
In the parsed version, we aggregate each of the components attributes into torch.Tensor arrays. So to get generator 1's upper bound on active generation, you'd look at: dat['gen']['pmax'][0] and get a float. Note that the index is 0-based and an integer, not 1-based and a string.
To access the reference solution, pass a model_type (or multiple) and then access dat["ref_solutions"][model_type].
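The aggregation described above (per-component JSON entries keyed by 1-based string ids, turned into 0-based tensors) can be sketched as follows. This assumes ids are the contiguous strings "1" to "n", so that id "k" lands at index k-1; aggregate_component is a hypothetical helper, not the parser's actual code.

```python
import torch


def aggregate_component(raw_json: dict, component: str, attr: str) -> torch.Tensor:
    """Stack one attribute of every component into a 0-based tensor."""
    entries = raw_json["data"][component]
    ids = sorted(entries, key=int)  # numeric order: "1", "2", ..., "10"
    return torch.tensor([entries[i][attr] for i in ids])
```

Sorting with key=int matters: plain string sorting would place "10" before "2".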
def validate_path(self, path: str | pathlib.Path) ‑> pathlib.Path
-
Validate the path to the HDF5 file.
class SOCModel
-
OPFModel for SOCOPF
Ancestors
- OPFModel
- abc.ABC
Subclasses
Class variables
var problem : SOCProblem
var violation : SOCViolation
Methods
def predict(self, pd: torch.Tensor, qd: torch.Tensor) ‑> dict[str, torch.Tensor]
-
Predict the SOCOPF primal solution for a given set of loads.
Args
pd
:Tensor
- Active power demand per load.
qd
:Tensor
- Reactive power demand per load.
Returns
dict[str, Tensor]
-
Dictionary containing the predicted primal solution.
pg
: Active power generation per generator or per bus.
qg
: Reactive power generation per generator or per bus.
w
: Squared voltage magnitude per bus.
wr
: Real part of the voltage phasor.
wi
: Imaginary part of the voltage phasor.
Inherited members
class SOCProblem (data_directory: str, dataset_name: str = 'SOCOPF', **parse_kwargs)
-
OPFProblem for SOCOPF
Ancestors
- OPFProblem
- abc.ABC
Instance variables
prop default_combos : dict[str, list[str]]
-
Default combos for SOCOPF:
- input: pd, qd
- target: pg, qg, w, wr, wi
prop default_order : list[str]
-
Default order for SOCOPF: input, target
prop feasibility_check : dict[str, str]
-
Default feasibility check for SOCOPF:
- termination_status: "LOCALLY_SOLVED"
- primal_status: "FEASIBLE_POINT"
- dual_status: "FEASIBLE_POINT"
prop violation : SOCViolation
-
SOCViolation object, created upon first access.
Inherited members
class SOCViolation (data: dict)
-
OPFViolation for SOCOPF
Ancestors
- OPFViolation
- torch.nn.modules.module.Module
- abc.ABC
Methods
def angle_difference_residual(self, wr: torch.Tensor, wi: torch.Tensor, clamp: bool = False) ‑> torch.Tensor
-
Compute the angle difference bound residual.
def balance_residual(self,
pd: torch.Tensor,
qd: torch.Tensor,
pg: torch.Tensor,
qg: torch.Tensor,
w: torch.Tensor,
pf: torch.Tensor,
pt: torch.Tensor,
qf: torch.Tensor,
qt: torch.Tensor,
clamp: bool = False,
embed_method: str = 'pad') ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the power balance residual.
Component-wise tensors are first embedded to the bus level using embed_method.
The shunt parameters g_s, b_s are assumed to be constant, matching the reference case.
\text{p\_viol} = \text{pg\_bus} - \text{pd\_bus} - \text{pt\_bus} - \text{pf\_bus} - \text{gs\_bus} \times \text{w}
\text{q\_viol} = \text{qg\_bus} - \text{qd\_bus} - \text{qt\_bus} - \text{qf\_bus} + \text{bs\_bus} \times \text{w}
Args
pd
:Tensor
- Active power demand per bus. (batch_size, nbus)
qd
:Tensor
- Reactive power demand per bus. (batch_size, nbus)
pg
:Tensor
- Active power generation per generator. (batch_size, ngen)
qg
:Tensor
- Reactive power generation per generator. (batch_size, ngen)
w
:Tensor
- Squared voltage magnitude per bus. (batch_size, nbus)
pf
:Tensor
- Active power flow from bus per branch. (batch_size, nbranch)
pt
:Tensor
- Active power flow to bus per branch. (batch_size, nbranch)
qf
:Tensor
- Reactive power flow from bus per branch. (batch_size, nbranch)
qt
:Tensor
- Reactive power flow to bus per branch. (batch_size, nbranch)
clamp
:bool
, optional- Apply an absolute value to the residual. Defaults to False.
embed_method
:str
, optional- Embedding method for bus-level components. Defaults to 'pad'. Must be one of 'pad', 'dense_matrix', or 'matrix'. See IncidenceMixin.*_to_bus.
Returns
Tensor
- Power balance residual for active power. (batch_size, nbus)
Tensor
- Power balance residual for reactive power. (batch_size, nbus)
def calc_violations(self,
pd: torch.Tensor,
qd: torch.Tensor,
pg: torch.Tensor,
qg: torch.Tensor,
w: torch.Tensor,
wr: torch.Tensor,
wi: torch.Tensor,
flows: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
reduction: str | None = 'mean',
clamp: bool = True) ‑> dict[str, torch.Tensor]
-
Calculate the violation of all the constraints.
The reduction is applied across the component dimension - e.g., 'mean' will do violation.mean(dim=1) where each violation is (batch, components)
Args
pd
:Tensor
- Real power demand. (batch, loads)
qd
:Tensor
- Reactive power demand. (batch, loads)
pg
:Tensor
- Real power generation. (batch, gens)
qg
:Tensor
- Reactive power generation. (batch, gens)
w
:Tensor
- Squared voltage magnitude. (batch, buses)
wr
:Tensor
- Real part of the voltage phasor. (batch, branches)
wi
:Tensor
- Imaginary part of the voltage phasor. (batch, branches)
flows
:tuple[Tensor, Tensor, Tensor, Tensor]
, optional- Power flows. (pf, pt, qf, qt)
reduction
:str
, optional- Reduction method. Defaults to 'mean'. Must be one of 'mean', 'sum', 'none'.
clamp
:bool
, optional- Clamp the residual to be non-negative (extract violations). Defaults to True.
Returns
- dict[str, Tensor]: Dictionary of violations.
vm_lower
: Voltage magnitude lower bound violation.
vm_upper
: Voltage magnitude upper bound violation.
pg_lower
: Real power generation lower bound violation.
pg_upper
: Real power generation upper bound violation.
qg_lower
: Reactive power generation lower bound violation.
qg_upper
: Reactive power generation upper bound violation.
thrm_1
: Thermal limit from violation.
thrm_2
: Thermal limit to violation.
p_balance
: Real power balance violation.
q_balance
: Reactive power balance violation.
dva_lower
: Voltage angle difference lower bound violation.
dva_upper
: Voltage angle difference upper bound violation.
def flows_from_voltage(self, w: torch.Tensor, wr: torch.Tensor, wi: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
-
Compute the power flows.
Returns
Tensor
- Real power flow per branch ( \mathbf{p}_f ). (batch_size, nbranch)
Tensor
- Real power flow per branch ( \mathbf{p}_t ). (batch_size, nbranch)
Tensor
- Reactive power flow per branch ( \mathbf{q}_f ). (batch_size, nbranch)
Tensor
- Reactive power flow per branch ( \mathbf{q}_t ). (batch_size, nbranch)
def jabr_residual(self, w: torch.Tensor, wr: torch.Tensor, wi: torch.Tensor, clamp: bool = False) ‑> torch.Tensor
-
Compute the Jabr constraint residual.
g_{\text{jabr}} = \text{wr}^2 + \text{wi}^2 - \text{w}_{fr} * \text{w}_{to}
Args
w
:Tensor
- Squared voltage magnitude per bus. (batch_size, nbus)
wr
:Tensor
- Real part of the voltage phasor. (batch_size, nbranch)
wi
:Tensor
- Imaginary part of the voltage phasor. (batch_size, nbranch)
clamp
:bool
, optional- Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor
- Jabr constraint residual. (batch_size, nbranch)
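The Jabr residual compares each branch's wr² + wi² against the product of the squared voltage magnitudes at its two ends. The sketch below assumes w is indexed per bus and wr, wi per branch, with hypothetical 0-based fbus/tbus arrays giving each branch's endpoints; jabr_residual_sketch is an illustration, not ml4opf's implementation.

```python
import torch


def jabr_residual_sketch(w: torch.Tensor, wr: torch.Tensor, wi: torch.Tensor,
                         fbus: torch.Tensor, tbus: torch.Tensor,
                         clamp: bool = False) -> torch.Tensor:
    """g = wr^2 + wi^2 - w[fbus] * w[tbus] per branch; feasible where g <= 0.

    w: (batch, nbus); wr, wi: (batch, nbranch); fbus, tbus: (nbranch,) int64
    """
    g = wr**2 + wi**2 - w[:, fbus] * w[:, tbus]
    return g.clamp(min=0) if clamp else g
```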
def objective(self, pg: torch.Tensor) ‑> torch.Tensor
-
Compute the objective function given the active power generation per generator.
Args
pg
:Tensor
- Active power generation per generator. (batch_size, ngen)
Returns
Tensor
- Objective function value. (batch_size)
def pg_bound_residual(self, pg: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the active power generation bound residual.
g_{\text{lower}} = \text{pmin} - \text{pg}
g_{\text{upper}} = \text{pg} - \text{pmax}
Args
pg
:Tensor
- Active power generation per generator. (batch_size, ngen)
clamp
:bool
, optional- Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor
- Lower bound residual. (batch_size, ngen)
Tensor
- Upper bound residual. (batch_size, ngen)
def qg_bound_residual(self, qg: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the reactive power generation bound residual.
g_{\text{lower}} = \text{qmin} - \text{qg}
g_{\text{upper}} = \text{qg} - \text{qmax}
Args
qg
:Tensor
- Reactive power generation per generator. (batch_size, ngen)
clamp
:bool
, optional- Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor
- Lower bound residual. (batch_size, ngen)
Tensor
- Upper bound residual. (batch_size, ngen)
def thermal_residual(self,
pf: torch.Tensor,
pt: torch.Tensor,
qf: torch.Tensor,
qt: torch.Tensor,
clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the thermal limit residual.
g_{\text{thrm}_1} = \text{pf}^2 + \text{qf}^2 - \text{s1max}
g_{\text{thrm}_2} = \text{pt}^2 + \text{qt}^2 - \text{s2max}
Args
pf
:Tensor
- Active power flow from bus per branch. (batch_size, nbranch)
pt
:Tensor
- Active power flow to bus per branch. (batch_size, nbranch)
qf
:Tensor
- Reactive power flow from bus per branch. (batch_size, nbranch)
qt
:Tensor
- Reactive power flow to bus per branch. (batch_size, nbranch)
clamp
:bool
, optional- Clamp the residual to be non-negative (extract violations). Defaults to False.
Returns
Tensor
- Thermal limit residual for from branch. (batch_size, nbranch)
Tensor
- Thermal limit residual for to branch. (batch_size, nbranch)
def w_bound_residual(self, w: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the bound residual of w.
Returns
Tensor
- Lower bound residual. (batch_size, nbus)
Tensor
- Upper bound residual. (batch_size, nbus)
def wi_bound_residual(self, wr: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the bound residual of wi.
Returns
Tensor
- Lower bound residual. (batch_size, nbranch)
Tensor
- Upper bound residual. (batch_size, nbranch)
def wr_bound_residual(self, wr: torch.Tensor, clamp: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]
-
Calculate the bound residual of wr.
Returns
Tensor
- Lower bound residual. (batch_size, nbranch)
Tensor
- Upper bound residual. (batch_size, nbranch)
Inherited members