import numpy as np import matplotlib.pyplot as plt from scipy.optimize import least_squares class STEPS: """实现子信号技术参数估计算法(STEPS)""" def __init__(self, fs, N0): """ 初始化STEPS估计器 :param fs: 采样频率 :param N0: 原始信号长度 """ self.fs = fs self.N0 = N0 # 子信号长度选择为原始信号的2/3(论文推荐的最优比例) self.N1 = int(2 * N0 / 3) self.N2 = self.N1 # 子信号起始索引 self.n1 = 0 self.n2 = N0 - self.N2 # 确保子信号长度有效 if self.N1 <= 0 or self.N2 <= 0: raise ValueError("子信号长度必须为正数,请增加原始信号长度") # 预计算频率分辨率 self.freq_res = fs / N0 def extract_subsignals(self, x): """ 从原始信号中提取两个子信号 :param x: 原始信号 :return: 两个子信号x1和x2 """ if len(x) != self.N0: raise ValueError(f"输入信号长度必须为{N0},实际为{len(x)}") x1 = x[self.n1: self.n1 + self.N1] x2 = x[self.n2: self.n2 + self.N2] return x1, x2 def compute_dft(self, x): """计算信号的DFT并返回幅度和相位""" X = np.fft.fft(x) amp = np.abs(X) phase = np.angle(X) return amp, phase def find_peak_bin(self, amp): """找到DFT幅度谱中的峰值位置""" return np.argmax(amp) def phase_correction(self, phase, k, N): """ 相位校正,消除线性相位影响 :param phase: 原始相位 :param k: 峰值所在的频点 :param N: 信号长度 :return: 校正后的相位 """ return phase - 2 * np.pi * k * (N - 1) / (2 * N) def objective_function(self, delta1, k1, k2, c1, c2, phi1, phi2): """用于求解非线性方程的目标函数""" # 计算方程左边 left_numerator = c1 * np.tan(phi2) - c2 * np.tan(phi1) left_denominator = c1 * c2 + np.tan(phi1) * np.tan(phi2) left = left_numerator / left_denominator # 计算方程右边 term = np.pi * (k1 + delta1) / self.N1 * (2 * (self.n2 - self.n1) + self.N2 - self.N1) right = np.tan(term) # 返回误差 return left - right def estimate_frequency(self, x1, x2): """ 估计信号频率 :param x1, x2: 两个子信号 :return: 估计的频率 """ # 计算子信号的DFT amp1, phase1 = self.compute_dft(x1) amp2, phase2 = self.compute_dft(x2) # 找到峰值位置 k1 = self.find_peak_bin(amp1) k2 = self.find_peak_bin(amp2) # 相位校正 phi1 = self.phase_correction(phase1[k1], k1, self.N1) phi2 = self.phase_correction(phase2[k2], k2, self.N2) # 计算c1和c2参数 c1 = np.sin(np.pi * k1 / self.N1) / np.sin(np.pi * (k1 + 1) / self.N1) c2 = np.sin(np.pi * k2 / self.N2) / np.sin(np.pi * (k2 + 1) / self.N2) # 求解非线性方程找到delta1 def func(delta): return self.objective_function(delta, k1, k2, c1, c2, phi1, phi2) # 使用最小二乘法求解 result = least_squares(func, x0=0.5, bounds=(0, 1)) delta1 = result.x[0] # 计算频率 l0 = (self.N0 / self.N1) * (k1 + delta1) f0 = (l0 / self.N0) * self.fs return f0, k1, amp1[k1] def estimate_amplitude(self, k1, amp1_peak): """估计信号振幅""" # 计算振幅校正因子 delta1 = (self.N1 / self.N0) * (self.fs / self.freq_res) - k1 correction = np.abs(np.sin(np.pi * delta1) / (self.N1 * np.sin(np.pi * (k1 + delta1) / self.N1))) amplitude = amp1_peak * correction return amplitude def estimate_phase(self, f0, k1, amp1_peak): """估计信号初始相位""" delta1 = (self.N1 * f0 / self.fs) - k1 phase = np.angle(amp1_peak) - np.pi * delta1 * (self.N1 - 1) / self.N1 # 将相位归一化到[-π, π]范围 return (phase + np.pi) % (2 * np.pi) - np.pi def estimate(self, x): """ 估计正弦信号的参数 :param x: 输入信号 :return: 频率、振幅和相位的估计值 """ # 提取子信号 x1, x2 = self.extract_subsignals(x) # 估计频率 f0, k1, amp1_peak = self.estimate_frequency(x1, x2) # 估计振幅 amp = self.estimate_amplitude(k1, amp1_peak) # 估计相位 phase = self.estimate_phase(f0, k1, amp1_peak) return f0, amp, phase # 演示如何使用STEPS类 if __name__ == "__main__": # 生成测试信号 fs = 1000 # 采样频率 N0 = 1024 # 信号长度 f_true = 75.3 # 真实频率 amp_true = 2.5 # 真实振幅 phase_true = np.pi / 3 # 真实相位 snr_db = 40 # 信噪比(dB) # 生成时间向量 t = np.arange(N0) / fs # 生成干净信号 x_clean = amp_true * np.sin(2 * np.pi * f_true * t + phase_true) # 添加噪声 snr = 10 ** (snr_db / 10) signal_power = np.sum(x_clean ** 2) / N0 noise_power = signal_power / snr noise = np.sqrt(noise_power) * np.random.randn(N0) x = x_clean + noise # 创建STEPS估计器并进行参数估计 steps = STEPS(fs, N0) f_est, amp_est, phase_est = steps.estimate(x) # 计算误差 f_error = np.abs(f_est - f_true) amp_error = np.abs(amp_est - amp_true) phase_error = np.abs(phase_est - phase_true) # 显示结果 print(f"真实频率: {f_true:.4f} Hz, 估计频率: {f_est:.4f} Hz, 误差: {f_error:.6f} Hz") print(f"真实振幅: {amp_true:.4f}, 估计振幅: {amp_est:.4f}, 误差: {amp_error:.6f}") print(f"真实相位: {phase_true:.4f} rad, 估计相位: {phase_est:.4f} rad, 误差: {phase_error:.6f} rad") # 绘制信号和频谱 plt.figure(figsize=(12, 8)) # 设置中文字体支持 plt.rcParams["font.family"] = ["SimHei"] plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题 # 绘制时域信号 plt.subplot(2, 1, 1) plt.plot(t, x_clean, label='干净信号') plt.plot(t, x, alpha=0.7, label='带噪声信号') plt.xlabel('时间 (s)') plt.ylabel('振幅') plt.title('时域信号') plt.legend() # 绘制频谱 plt.subplot(2, 1, 2) freq = np.fft.fftfreq(N0, 1 / fs) x_fft = np.fft.fft(x) plt.plot(freq[:N0 // 2], 20 * np.log10(np.abs(x_fft[:N0 // 2])), label='信号频谱') plt.xlabel('频率 (Hz)') plt.ylabel('幅度 (dB)') plt.title('信号频谱') plt.xlim(0, 150) # 限制频率范围以便观察 plt.axvline(f_true, color='r', linestyle='--', label=f'真实频率: {f_true} Hz') plt.axvline(f_est, color='g', linestyle='--', label=f'估计频率: {f_est:.2f} Hz') plt.legend() plt.tight_layout() plt.show()