utils.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import numpy as np
  2. import scipy.signal as scipysig
  3. from numpy.typing import NDArray
  4. from typing import Optional
  5. from pydeepc import Data
  6. class System(object):
  7. """
  8. Represents a dynamical system that can be simulated
  9. """
  10. def __init__(self, sys: scipysig.StateSpace, x0: Optional[NDArray[np.float64]] = None, noise_std: float = 0.5):
  11. """
  12. :param sys: a linear system
  13. :param x0: initial state
  14. :param noise_std: Standard deviation of the measurement noise
  15. """
  16. assert x0 is None or sys.A.shape[0] == len(x0), 'Invalid initial condition'
  17. self.sys = sys
  18. self.x0 = x0 if x0 is not None else np.zeros(sys.A.shape[0])
  19. self.u = None
  20. self.y = None
  21. self.noise_std = noise_std
  22. def apply_input(self, u: NDArray[np.float64]) -> Data:
  23. """
  24. Applies an input signal to the system.
  25. :param u: input signal. Needs to be of shape T x M, where T is the batch size and M is the number of features
  26. :return: tuple that contains the (input,output) of the system
  27. """
  28. T = len(u)
  29. if T > 1:
  30. # If u is a signal of length > 1 use dlsim for quicker computation
  31. t, y, x0 = scipysig.dlsim(self.sys, u, t = np.arange(T) * self.sys.dt, x0 = self.x0)
  32. self.x0 = x0[-1]
  33. else:
  34. y = self.sys.C @ self.x0
  35. self.x0 = self.sys.A @ self.x0.flatten() + self.sys.B @ u.flatten()
  36. y = y + self.noise_std * np.random.normal(size = T).reshape(T, 1)
  37. self.u = np.vstack([self.u, u]) if self.u is not None else u
  38. self.y = np.vstack([self.y, y]) if self.y is not None else y
  39. return Data(u, y)
  40. def get_last_n_samples(self, n: int) -> Data:
  41. """
  42. Returns the last n samples
  43. :param n: integer value
  44. """
  45. assert self.u.shape[0] >= n, 'Not enough samples are available'
  46. return Data(self.u[-n:], self.y[-n:])
  47. def get_all_samples(self) -> Data:
  48. """
  49. Returns all samples
  50. """
  51. return Data(self.u, self.y)
  52. def reset(self, data_ini: Optional[Data] = None, x0: Optional[NDArray[np.float64]] = None):
  53. """
  54. Reset initial state and collected data
  55. """
  56. self.u = None if data_ini is None else data_ini.u
  57. self.y = None if data_ini is None else data_ini.y
  58. self.x0 = x0 if x0 is not None else np.zeros(self.sys.A.shape[0])