test_RGSGQF_filter.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import numpy as np
  2. # 定义权重函数
  3. def weight_function(zeta, k1=1.5, k2=5):
  4. abs_zeta = np.abs(zeta)
  5. psi = np.zeros_like(abs_zeta)
  6. psi[abs_zeta <= k1] = 1
  7. mask = (abs_zeta > k1) & (abs_zeta <= k2)
  8. psi[mask] = (1 - ((abs_zeta[mask] - k1) ** 2 / (k2 - k1) ** 2)) ** 2
  9. # 确保权重不为零
  10. psi[psi < 1e-6] = 1e-6
  11. # print(f"psi={psi}")
  12. return psi
  13. # RESGQF滤波函数
  14. def resgqf_filter(x_hat, P, z, f, h, Q, R, k1=1.5, k2=5):
  15. # 状态预测
  16. x_pred = f(x_hat)
  17. P_pred = P + Q
  18. # 计算量测残差
  19. z_hat = h(x_pred)
  20. r = z - z_hat
  21. # 标准化残差
  22. R_inv_sqrt = np.linalg.inv(np.linalg.cholesky(R))
  23. zeta = R_inv_sqrt @ r
  24. # 计算权重矩阵
  25. psi = weight_function(zeta)
  26. Psi = np.diag(psi)
  27. # 确保 Psi 是一个 2D 矩阵
  28. if Psi.ndim == 1:
  29. Psi = np.diag(Psi)
  30. # # 添加正则化项以避免矩阵不可逆
  31. # epsilon = 1e-6 # 小的正则化项
  32. # Psi += epsilon * np.eye(Psi.shape[0])
  33. # 更新量测噪声协方差
  34. R_bar = R_inv_sqrt @ np.linalg.inv(Psi) @ R_inv_sqrt.T
  35. # 计算跨协方差 P_xz 和自协方差 P_zz
  36. P_xz = P_pred[0, 0] * z_hat[0]
  37. P_zz = z_hat[0]**2 + R_bar[0, 0]
  38. # 确保 P_zz 不为零以避免除以零错误
  39. if P_zz < 1e-6:
  40. P_zz = 1e-6
  41. # 计算卡尔曼增益
  42. K = P_xz / P_zz
  43. # 状态更新
  44. x_hat = x_pred + K * (z - z_hat[0])
  45. P = P_pred - K**2 * P_xz
  46. # 打印调试信息
  47. print(f"量测: {z[0]}, 预测状态: {x_pred[0]}, 量测残差: {r[0]}, 标准化残差: {zeta[0]}, 权重: {psi[0]}, 量测噪声协方差: {R_bar[0, 0]}, 卡尔曼增益: {K}, 估计状态: {x_hat[0]}, 协方差: {P[0, 0]}")
  48. return x_hat, P
  49. # 示例参数设置
  50. # 状态转移函数
  51. def f(x):
  52. return x # 简单的恒等转移
  53. # 量测函数
  54. def h(x):
  55. return x # 简单的恒等映射
  56. # 初始状态估计
  57. x_hat = np.array([0.1])
  58. # 初始协方差
  59. P = np.array([[1.0]])
  60. # 过程噪声协方差
  61. Q = np.array([[0.1]])
  62. # 量测噪声协方差
  63. R = np.array([[1.0]])
  64. # 模拟量测值,包含异常值
  65. measurements = [1.1, 1.2, 10.0, 1.3, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 10, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4] # 10.0 为异常值
  66. # 运行滤波
  67. for z in measurements:
  68. x_hat, P = resgqf_filter(x_hat, P, np.array([z]), f, h, Q, R)