Source code for disropt.functions.extended_function

import numpy as np
import warnings
from .abstract_function import AbstractFunction
from .utilities import check_input


[docs]class ExtendedFunction(AbstractFunction): """Function with extended variable .. math:: f(x, y) = x with :math:`x\\in \\mathbb{R}^{n}, y\\in \\mathbb{R}^{m}` Args: fn (AbstractFunction): input function n_var: number of additional variables. Defaults to 1 axis: axis along which the additional variables are appended. Defaults to -1 (the last valid one) pos: position index of the old variable vector. Defaults to 0 Raises: TypeError: fn must be a AbstractFunction TypeError: n_var must be a positive int TypeError: axis must be int """ # TODO implement __new__ to maintain existing structure (e.g. QuadraticForm, AffineForm) def __new__(cls, fn: AbstractFunction, n_var: int=1, axis: int=-1, pos: int=0): instance = object.__new__(cls) if fn.is_affine: # TODO extend to non mono-dimensional case from .variable import Variable from .affine_form import AffineForm A, b = fn.get_parameters() A_new = np.zeros((A.shape[0]+n_var, A.shape[1])) A_new[pos:pos+A.shape[0], :] = A instance = AffineForm(Variable(A.shape[0] + n_var), A_new, b) elif fn.is_quadratic: # TODO extend to non mono-dimensional case from .variable import Variable from .quadratic_form import QuadraticForm P, q, r = fn.get_parameters() P_new = np.pad(P, ((pos, n_var-pos), (pos, n_var-pos)), 'constant') # TODO q_new = np.zeros((q.shape[0]+n_var, q.shape[1])) q_new[pos:pos+q.shape[0], :] = q instance = QuadraticForm(Variable(P.shape[0] + n_var), P_new, q_new, r) return instance def __init__(self, fn: AbstractFunction, n_var: int=1, axis: int=-1, pos: int=0): if not isinstance(fn, AbstractFunction): raise TypeError("fn must be a AbstractFunction object") if not isinstance(n_var, int) or n_var <= 0: raise TypeError("n_var must be a positive int") if not isinstance(axis, int): raise TypeError("axis must be int") # internal variables self.n_var = n_var self.axis = [i for i in range(len(fn.input_shape)) if fn.input_shape[i] > 1][axis] self.pos = pos self.fn = fn self.input_shape_original = fn.input_shape indices = [range(i) for i in fn.input_shape] indices[self.axis] = range(pos, pos+fn.input_shape[self.axis]) self.eval_index = np.ix_(*tuple(indices)) # AbstractFunction variables new_shape = list(fn.input_shape) new_shape[self.axis] += n_var self.input_shape = tuple(new_shape) self.output_shape = fn.output_shape self.differentiable = fn.differentiable self.affine = False self.quadratic = False super().__init__() def _expression(self): expression = 'Extended Function({}, {} additional vars along axis {} old position {})'.format(self.fn._expression(), self.n_var, self.axis, self.pos) return expression @check_input def eval(self, x: np.ndarray) -> np.ndarray: return self.fn.eval(x[self.eval_index]) def _to_cvxpy(self): import cvxpy as cvx return self.fn._extend_variable(self.n_var, self.axis, self.pos)._to_cvxpy()