# deepc_test.py (1.9 KB)

import warnings

import numpy as np
from scipy.optimize import minimize
  3. # 系统参数
  4. A = 0.8
  5. B = 0.5
  6. C = 1
  7. D = 0.2
  8. # 参考轨迹函数
  9. def reference_trajectory(t):
  10. return np.sin(0.1 * t)
  11. # 生成输入数据
  12. def generate_input_data(T):
  13. return np.random.uniform(-1, 1, T)
  14. # 系统模拟
  15. def simulate_system(u):
  16. x = 0 # 初始状态
  17. y = []
  18. for i in range(len(u)):
  19. x = A * x + B * u[i]
  20. y.append(C * x + D * u[i])
  21. return np.array(y)
  22. # DeePC算法实现
  23. def deepc_algorithm(T, Tini, N):
  24. # 生成输入数据
  25. ud = generate_input_data(T)
  26. # 模拟系统得到输出数据
  27. yd = simulate_system(ud)
  28. # 划分数据
  29. Up, Yp, Uf, Yf = partition_data(ud, yd, Tini, N)
  30. # 构建优化问题
  31. def objective_function(g):
  32. u = Uf @ g
  33. y = Yf @ g
  34. cost = np.sum((y - reference_trajectory(np.arange(Tini, Tini + N))) ** 2) + 0.1 * np.sum(u ** 2)
  35. return cost
  36. # 等式约束
  37. def equality_constraint(g):
  38. return np.concatenate((Up @ g - ud[:Tini], Yp @ g - yd[:Tini], Uf @ g, Yf @ g))
  39. # 初始猜测
  40. g0 = np.zeros((T - Tini - N + 1, 1))
  41. # 约束条件
  42. constraints = {'type': 'eq', 'fun': equality_constraint}
  43. # 求解优化问题
  44. result = minimize(objective_function, g0, constraints=constraints)
  45. # 最优输入序列
  46. u_opt = Uf @ result.x
  47. return u_opt
  48. # 划分数据函数
  49. def partition_data(ud, yd, Tini, N):
  50. Up = np.zeros((Tini, T - Tini - N + 1))
  51. Yp = np.zeros((Tini, T - Tini - N + 1))
  52. Uf = np.zeros((N, T - Tini - N + 1))
  53. Yf = np.zeros((N, T - Tini - N + 1))
  54. for i in range(T - Tini - N + 1):
  55. Up[:, i] = ud[i:i + Tini]
  56. Yp[:, i] = yd[i:i + Tini]
  57. Uf[:, i] = ud[i + Tini:i + Tini + N]
  58. Yf[:, i] = yd[i + Tini:i + Tini + N]
  59. return Up, Yp, Uf, Yf
  60. # 测试
  61. T = 100
  62. Tini = 10
  63. N = 5
  64. u_opt = deepc_algorithm(T, Tini, N)
  65. print(u_opt)