test_STEPS.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. from scipy.optimize import least_squares
  4. class STEPS:
  5. """实现子信号技术参数估计算法(STEPS)"""
  6. def __init__(self, fs, N0):
  7. """
  8. 初始化STEPS估计器
  9. :param fs: 采样频率
  10. :param N0: 原始信号长度
  11. """
  12. self.fs = fs
  13. self.N0 = N0
  14. # 子信号长度选择为原始信号的2/3(论文推荐的最优比例)
  15. self.N1 = int(2 * N0 / 3)
  16. self.N2 = self.N1
  17. # 子信号起始索引
  18. self.n1 = 0
  19. self.n2 = N0 - self.N2
  20. # 确保子信号长度有效
  21. if self.N1 <= 0 or self.N2 <= 0:
  22. raise ValueError("子信号长度必须为正数,请增加原始信号长度")
  23. # 预计算频率分辨率
  24. self.freq_res = fs / N0
  25. def extract_subsignals(self, x):
  26. """
  27. 从原始信号中提取两个子信号
  28. :param x: 原始信号
  29. :return: 两个子信号x1和x2
  30. """
  31. if len(x) != self.N0:
  32. raise ValueError(f"输入信号长度必须为{N0},实际为{len(x)}")
  33. x1 = x[self.n1: self.n1 + self.N1]
  34. x2 = x[self.n2: self.n2 + self.N2]
  35. return x1, x2
  36. def compute_dft(self, x):
  37. """计算信号的DFT并返回幅度和相位"""
  38. X = np.fft.fft(x)
  39. amp = np.abs(X)
  40. phase = np.angle(X)
  41. return amp, phase
  42. def find_peak_bin(self, amp):
  43. """找到DFT幅度谱中的峰值位置"""
  44. return np.argmax(amp)
  45. def phase_correction(self, phase, k, N):
  46. """
  47. 相位校正,消除线性相位影响
  48. :param phase: 原始相位
  49. :param k: 峰值所在的频点
  50. :param N: 信号长度
  51. :return: 校正后的相位
  52. """
  53. return phase - 2 * np.pi * k * (N - 1) / (2 * N)
  54. def objective_function(self, delta1, k1, k2, c1, c2, phi1, phi2):
  55. """用于求解非线性方程的目标函数"""
  56. # 计算方程左边
  57. left_numerator = c1 * np.tan(phi2) - c2 * np.tan(phi1)
  58. left_denominator = c1 * c2 + np.tan(phi1) * np.tan(phi2)
  59. left = left_numerator / left_denominator
  60. # 计算方程右边
  61. term = np.pi * (k1 + delta1) / self.N1 * (2 * (self.n2 - self.n1) + self.N2 - self.N1)
  62. right = np.tan(term)
  63. # 返回误差
  64. return left - right
  65. def estimate_frequency(self, x1, x2):
  66. """
  67. 估计信号频率
  68. :param x1, x2: 两个子信号
  69. :return: 估计的频率
  70. """
  71. # 计算子信号的DFT
  72. amp1, phase1 = self.compute_dft(x1)
  73. amp2, phase2 = self.compute_dft(x2)
  74. # 找到峰值位置
  75. k1 = self.find_peak_bin(amp1)
  76. k2 = self.find_peak_bin(amp2)
  77. # 相位校正
  78. phi1 = self.phase_correction(phase1[k1], k1, self.N1)
  79. phi2 = self.phase_correction(phase2[k2], k2, self.N2)
  80. # 计算c1和c2参数
  81. c1 = np.sin(np.pi * k1 / self.N1) / np.sin(np.pi * (k1 + 1) / self.N1)
  82. c2 = np.sin(np.pi * k2 / self.N2) / np.sin(np.pi * (k2 + 1) / self.N2)
  83. # 求解非线性方程找到delta1
  84. def func(delta):
  85. return self.objective_function(delta, k1, k2, c1, c2, phi1, phi2)
  86. # 使用最小二乘法求解
  87. result = least_squares(func, x0=0.5, bounds=(0, 1))
  88. delta1 = result.x[0]
  89. # 计算频率
  90. l0 = (self.N0 / self.N1) * (k1 + delta1)
  91. f0 = (l0 / self.N0) * self.fs
  92. return f0, k1, amp1[k1]
  93. def estimate_amplitude(self, k1, amp1_peak):
  94. """估计信号振幅"""
  95. # 计算振幅校正因子
  96. delta1 = (self.N1 / self.N0) * (self.fs / self.freq_res) - k1
  97. correction = np.abs(np.sin(np.pi * delta1) / (self.N1 * np.sin(np.pi * (k1 + delta1) / self.N1)))
  98. amplitude = amp1_peak * correction
  99. return amplitude
  100. def estimate_phase(self, f0, k1, amp1_peak):
  101. """估计信号初始相位"""
  102. delta1 = (self.N1 * f0 / self.fs) - k1
  103. phase = np.angle(amp1_peak) - np.pi * delta1 * (self.N1 - 1) / self.N1
  104. # 将相位归一化到[-π, π]范围
  105. return (phase + np.pi) % (2 * np.pi) - np.pi
  106. def estimate(self, x):
  107. """
  108. 估计正弦信号的参数
  109. :param x: 输入信号
  110. :return: 频率、振幅和相位的估计值
  111. """
  112. # 提取子信号
  113. x1, x2 = self.extract_subsignals(x)
  114. # 估计频率
  115. f0, k1, amp1_peak = self.estimate_frequency(x1, x2)
  116. # 估计振幅
  117. amp = self.estimate_amplitude(k1, amp1_peak)
  118. # 估计相位
  119. phase = self.estimate_phase(f0, k1, amp1_peak)
  120. return f0, amp, phase
  121. # 演示如何使用STEPS类
  122. if __name__ == "__main__":
  123. # 生成测试信号
  124. fs = 1000 # 采样频率
  125. N0 = 1024 # 信号长度
  126. f_true = 75.3 # 真实频率
  127. amp_true = 2.5 # 真实振幅
  128. phase_true = np.pi / 3 # 真实相位
  129. snr_db = 40 # 信噪比(dB)
  130. # 生成时间向量
  131. t = np.arange(N0) / fs
  132. # 生成干净信号
  133. x_clean = amp_true * np.sin(2 * np.pi * f_true * t + phase_true)
  134. # 添加噪声
  135. snr = 10 ** (snr_db / 10)
  136. signal_power = np.sum(x_clean ** 2) / N0
  137. noise_power = signal_power / snr
  138. noise = np.sqrt(noise_power) * np.random.randn(N0)
  139. x = x_clean + noise
  140. # 创建STEPS估计器并进行参数估计
  141. steps = STEPS(fs, N0)
  142. f_est, amp_est, phase_est = steps.estimate(x)
  143. # 计算误差
  144. f_error = np.abs(f_est - f_true)
  145. amp_error = np.abs(amp_est - amp_true)
  146. phase_error = np.abs(phase_est - phase_true)
  147. # 显示结果
  148. print(f"真实频率: {f_true:.4f} Hz, 估计频率: {f_est:.4f} Hz, 误差: {f_error:.6f} Hz")
  149. print(f"真实振幅: {amp_true:.4f}, 估计振幅: {amp_est:.4f}, 误差: {amp_error:.6f}")
  150. print(f"真实相位: {phase_true:.4f} rad, 估计相位: {phase_est:.4f} rad, 误差: {phase_error:.6f} rad")
  151. # 绘制信号和频谱
  152. plt.figure(figsize=(12, 8))
  153. # 设置中文字体支持
  154. plt.rcParams["font.family"] = ["SimHei"]
  155. plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
  156. # 绘制时域信号
  157. plt.subplot(2, 1, 1)
  158. plt.plot(t, x_clean, label='干净信号')
  159. plt.plot(t, x, alpha=0.7, label='带噪声信号')
  160. plt.xlabel('时间 (s)')
  161. plt.ylabel('振幅')
  162. plt.title('时域信号')
  163. plt.legend()
  164. # 绘制频谱
  165. plt.subplot(2, 1, 2)
  166. freq = np.fft.fftfreq(N0, 1 / fs)
  167. x_fft = np.fft.fft(x)
  168. plt.plot(freq[:N0 // 2], 20 * np.log10(np.abs(x_fft[:N0 // 2])), label='信号频谱')
  169. plt.xlabel('频率 (Hz)')
  170. plt.ylabel('幅度 (dB)')
  171. plt.title('信号频谱')
  172. plt.xlim(0, 150) # 限制频率范围以便观察
  173. plt.axvline(f_true, color='r', linestyle='--', label=f'真实频率: {f_true} Hz')
  174. plt.axvline(f_est, color='g', linestyle='--', label=f'估计频率: {f_est:.2f} Hz')
  175. plt.legend()
  176. plt.tight_layout()
  177. plt.show()