Ver Fonte

添加新的测试代码

alex há 2 meses atrás
pai
commit
cb695c6304

+ 7 - 0
.idea/MarsCodeWorkspaceAppSettings.xml

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="com.codeverse.userSettings.MarscodeWorkspaceAppSettingsState">
+    <option name="ckgOperationStatus" value="SUCCESS" />
+    <option name="progress" value="0.77272725" />
+  </component>
+</project>

+ 6 - 0
.idea/PythonProject.iml

@@ -7,4 +7,10 @@
     <orderEntry type="jdk" jdkName="Python 3.11 (PythonProject)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
+  <component name="PackageRequirementsSettings">
+    <option name="requirementsPath" value="" />
+  </component>
+  <component name="TestRunnerService">
+    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
+  </component>
 </module>

+ 4 - 0
.idea/encodings.xml

@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Encoding" defaultCharsetForPropertiesFiles="UTF-8" />
+</project>

+ 3 - 0
1.TXT

@@ -0,0 +1,3 @@
+我将帮助你在
+testUI.py
+ 中实现点击第一个菜单下的第一个子菜单时读取 1.txt 文件,并将文件内容显示在主界面上。以下是完整的代码示例:

+ 10 - 0
MZ/.idea/MZ.iml

@@ -0,0 +1,10 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/venv" />
+    </content>
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

+ 77 - 0
example.py

@@ -0,0 +1,77 @@
+import tkinter as tk
+from tkhtmlview import HTMLScrolledText
+
+root = tk.Tk()
+root.geometry("780x640")
+html_label = HTMLScrolledText(
+    root,
+    html="""
+      <h1 style="color: red; text-align: center"> Hello World </h1>
+      <h2> Smaller </h2>
+      <h3> Little Smaller </h3>
+      <h4> Little Little Smaller </h4>
+      <h5> Little Little Little Smaller </h5>
+      <h6> Little Little Little Little Smaller </h6>
+      <b>This is a Bold text</b><br/>
+      <strong>This is an Important text</strong><br/>
+      <i>This is a Italic text</i><br/>
+      <em>This is a Emphasized text</em><br/>
+      <em>This is a <strong>Strong Emphasized</strong>   text   </em>.<br/>
+      <img src="https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"></img>
+      <ul>
+      <li> One </li>
+        <ul>
+            <li> One.Zero </li>
+            <li> One.One </li>
+        </ul>
+      <li> Two </li>
+      <li> Three </li>
+      <li> Four </li>
+      </ul>
+
+      <h3> Paragraph </h3>
+      <p>Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed ut quam sapien. Maecenas porta tempus mauris sed ullamcorper. Nulla facilisi. Nulla facilisi. Mauris tristique ipsum et efficitur lobortis. Sed pharetra ipsum non lacinia dignissim. Ut condimentum vulputate sem eget scelerisque. Curabitur ornare augue enim, sed volutpat enim finibus id. </p>
+
+      <h3>Table</h3
+      <table>
+          <tr>
+            <th style="background-color: silver">T Header 1</th>
+            <th style="background-color: silver">T Header 2</th>
+          </tr>
+          <tr>
+            <td>ABC</td>
+            <td><em>123</em></td>
+          </tr>
+          <tr>
+            <td>DEF</td>
+            <td><i>456</i></td>
+          </tr>
+      </table>
+
+      <h3> Preformatted text: spaces and newlines preserved </h3>
+      <pre style="font-family: Consolas; font-size: 80%">
+<b>Tags      Attributes          Notes</b>
+a         style, <span style="color: RoyalBlue">href         </span>-
+img       <u>src</u>, width, height  <i>experimental</i>
+ol        <span style="color: Tomato">style, type         </span>-
+ul        <i>style               </i>bullet glyphs only
+div       style               -
+      </pre>
+
+      <h3> Code: spaces and newlines ignored </h3>
+      <code style="font-family: Consolas; font-size: 80%">
+<b>Tags      Attributes          Notes</b><br/>
+a         style, <span style="color: RoyalBlue">href         </span>-
+img       <u>src</u>, width, height  <i>experimental</i>
+ol        <span style="color: Tomato">style, type         </span>-
+ul        <i>style               </i>bullet glyphs only
+div       style               -
+      </code>
+
+      <p>Inline <code>code blocks</code> displayed without newlines</p>
+
+    """,
+)
+html_label.pack(fill="both", expand=True)
+html_label.fit_height()
+root.mainloop()

+ 35 - 0
showhtml.py

@@ -0,0 +1,35 @@
+import os
+import tkinter as tk
+from tkinterweb import HtmlFrame
+
+
+def main():
+    root = tk.Tk()
+    root.title("设置基准路径示例")
+    root.geometry("800x600")
+
+    # 创建HTML浏览器组件
+    browser = HtmlFrame(root)
+    browser.pack(fill="both", expand=True)
+
+    # HTML文件路径(请替换为你的实际路径)
+    html_file = "data\\2.htm"
+
+    # 计算HTML文件所在的目录
+    html_dir = os.path.dirname(os.path.abspath(html_file))
+
+    # 读取HTML内容
+    with open(html_file, "r", encoding="utf-8") as file:
+        html_content = file.read()
+
+    # 设置基准路径(使用file://协议)
+    # base_uri = f"file://{html_dir}/"
+
+    # 加载HTML内容并设置基准路径
+    browser.add_html(html_content)
+
+    root.mainloop()
+
+
+if __name__ == "__main__":
+    main()

+ 135 - 0
testUI.py

@@ -0,0 +1,135 @@
+import ttkbootstrap as ttk
+import os
+import tkinter as tk
+from tkhtmlview import HTMLLabel
+# from tkinterweb import HtmlFrame
+
+
+def on_menu_item_click(label):
+    if label == "Item 1":
+        display_html_file("data/1.html")
+    elif label == "Item 2":
+        display_html_file("data/2.html")
+    elif label == "Sub Item 1":
+        display_text_and_entries()
+    elif label == "about":
+        display_html_file("data/about.html")
+    else:
+        print(f"Clicked on {label}")
+
+def display_html_file(filepath):
+    # 清空当前窗口内容,但保留菜单栏
+    for widget in window.winfo_children():
+        if not isinstance(widget, ttk.Menu):
+            widget.destroy()
+
+    if os.path.exists(filepath):
+        with open(filepath, 'r', encoding='utf-8') as file:
+            html_content = file.read()
+
+        # 创建 HTMLLabel 来显示 HTML 内容
+        global html_label
+        html_label = HTMLLabel(window, html=html_content,wrap=ttk.CHAR)
+        # tkinterweb
+        # html_label = HtmlFrame(window)
+        #
+        # html_label.load_file("C:\\Users\\wang\\PycharmProjects\\PythonProject\\data\\1.html")
+
+        # html_label.add_html(html_content)
+        #html_label.load_html(html_content,base_url="C:\\Users\\wang\\PycharmProjects\\PythonProject\\data")
+
+
+        html_label.pack(side=ttk.TOP, fill=ttk.BOTH, expand=True)
+        html_label.fit_height()
+    else:
+        # 创建文本区域用于显示错误信息
+        global text_area
+        text_area = ttk.Text(window, wrap=ttk.WORD, font=('Arial', 12))
+        text_area.pack(side=ttk.TOP, fill=ttk.BOTH, expand=True)
+        text_area.insert(ttk.END, f"File {filepath} not found.")
+
+def display_text_and_entries():
+    # 清空当前窗口内容,但保留菜单栏
+    for widget in window.winfo_children():
+        if not isinstance(widget, ttk.Menu):
+            widget.destroy()
+
+    global entry_list
+    entry_list = []
+    global label_list
+    label_list = []
+
+    # 创建10行界面
+    for i in range(1, 11):
+        frame = ttk.Frame(window)
+        frame.pack(pady=5, fill=ttk.X)
+
+        title_label = ttk.Label(frame, text=f"Title {i}", font=('Arial', 12))
+        title_label.pack(side=ttk.LEFT, padx=5)
+
+        entry = ttk.Entry(frame, font=('Arial', 12), width=20)
+        entry.pack(side=ttk.LEFT, padx=5)
+
+        entry_list.append(entry)
+
+        # 添加一个新的Label来显示内容
+        label = ttk.Label(frame, text=f"text {i}", font=('Arial', 12))
+        label.pack(side=ttk.LEFT, padx=5)
+
+        label_list.append(label)
+
+    # 创建按钮
+    fill_button = ttk.Button(window, text="Fill All with Hello", command=fill_entries_with_hello)
+    fill_button.pack(pady=10)
+
+def fill_entries_with_hello():
+    for entry in entry_list:
+        entry.delete(0, ttk.END)
+        entry.insert(0, "Hello")
+
+def display_text_area():
+    # 清空当前窗口内容,但保留菜单栏
+    for widget in window.winfo_children():
+        if not isinstance(widget, ttk.Menu):
+            widget.destroy()
+
+    global text_area
+    text_area = ttk.Text(window, wrap=ttk.WORD, font=('Arial', 12))
+    text_area.pack(side=ttk.TOP, fill=ttk.BOTH, expand=True)
+
+# 创建窗口并设置主题为 'flatly'
+window = ttk.Window(themename='flatly')
+
+# 设置窗口大小为 2000x1200
+window.geometry("2000x1200")
+
+# 初始显示 text_area
+display_text_area()
+
+# 创建菜单栏
+menu_bar = ttk.Menu(window)
+
+# 创建第一个菜单及其子菜单
+first_menu = ttk.Menu(menu_bar, tearoff=0)
+for i in range(1, 11):
+    menu_label = f"Item {i}"
+    first_menu.add_command(label=menu_label, command=lambda label=menu_label: on_menu_item_click(label))
+menu_bar.add_cascade(label="Menu 1", menu=first_menu)
+
+# 创建第二个菜单及其子菜单
+second_menu = ttk.Menu(menu_bar, tearoff=0)
+second_menu.add_command(label="Sub Item 1", command=lambda: on_menu_item_click("Sub Item 1"))
+second_menu.add_command(label="Sub Item 2", command=lambda: on_menu_item_click("Sub Item 2"))
+menu_bar.add_cascade(label="Menu 2", menu=second_menu)
+
+# 创建第三个菜单及其子菜单
+third_menu = ttk.Menu(menu_bar, tearoff=0)
+third_menu.add_command(label="关于本软件", command=lambda: on_menu_item_click("about"))
+#third_menu.add_command(label="Sub Item 2", command=lambda: on_menu_item_click("Sub Item 2"))
+menu_bar.add_cascade(label="关于", menu=third_menu)
+
+# 设置窗口的菜单栏
+window.config(menu=menu_bar)
+
+# 运行窗口主循环
+window.mainloop()

+ 231 - 0
test_3D.py

@@ -0,0 +1,231 @@
+import vtk
+
+class ButtonCallback(object):
+    def __init__(self, renderer, actor1):
+        self.renderer = renderer
+        self.actor1 = actor1
+
+    def execute(self, obj, event):
+        print("Button was clicked!")
+        # Move actor1
+        if self.actor1:
+            current_position = self.actor1.GetPosition()
+            self.actor1.SetPosition(current_position[0] + 0.1, current_position[1], current_position[2])
+            self.renderer.GetRenderWindow().Render()
+
+class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
+    def __init__(self, parent=None):
+        super().__init__()
+        self.AddObserver("LeftButtonPressEvent", self.left_button_press)
+        self.actor1 = None
+        self.actor2 = None
+        self.label1 = None
+        self.label2 = None
+
+    def left_button_press(self, obj, event):
+        try:
+            click_pos = self.GetInteractor().GetEventPosition()
+            print(f'left_button_press({click_pos[0]}, {click_pos[1]})')
+
+            # Get the default renderer
+            renderer = self.GetDefaultRenderer()
+            if renderer is None:
+                print("Error: Default renderer is None")
+                return
+
+            # Create a picker and perform picking
+            picker = vtk.vtkPropPicker()
+            picker.Pick(click_pos[0], click_pos[1], 0, renderer)
+
+            # Get the picked actor
+            picked_actor = picker.GetActor()
+            if picked_actor is None:
+                print("No actor was picked")
+            elif picked_actor == self.actor1:
+                print("Actor 1 was picked")
+                # Define the movement of actor2
+                self.actor2.SetPosition(self.actor2.GetPosition()[0] + 0.1,
+                                       self.actor2.GetPosition()[1],
+                                       self.actor2.GetPosition()[2])
+            elif picked_actor == self.actor2:
+                print("Actor 2 was picked")
+                # Define the movement of actor1
+                self.actor1.SetPosition(self.actor1.GetPosition()[0] + 0.1,
+                                       self.actor1.GetPosition()[1],
+                                       self.actor1.GetPosition()[2])
+            else:
+                print("Some other actor was picked")
+
+            self.GetInteractor().Render()
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+def create_button(renderer, callback, position, size, label):
+    # Create a TextActor to display the button text
+    text_actor = vtk.vtkTextActor()
+    text_property = vtk.vtkTextProperty()
+    text_property.SetFontFamilyToArial()
+    text_property.SetFontSize(12)
+    text_property.BoldOn()
+    text_property.ItalicOff()
+    text_property.ShadowOn()
+    text_property.SetJustificationToCentered()
+    text_property.SetColor(0, 0, 0)  # Text color
+    text_actor.SetInput(label)
+    text_actor.SetTextProperty(text_property)
+
+    # Calculate text position, center it
+    text_actor.SetDisplayPosition(position[0] + size[0] // 2, position[1] + size[1] // 2 - 10)
+
+    # Create an ImageData to create the background image
+    width, height = size
+    image_data = vtk.vtkImageData()
+    image_data.SetDimensions(width, height, 1)
+    image_data.SetSpacing(1, 1, 1)
+    image_data.SetOrigin(0, 0, 0)
+    image_data.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 4)
+
+    colors = vtk.vtkNamedColors()
+    background_color = colors.GetColor4ub("LightGray")
+    highlight_color = colors.GetColor4ub("LightBlue")
+
+    # Set the default background color
+    scalars = image_data.GetPointData().GetScalars()
+    for y in range(height):
+        for x in range(width):
+            offset = y * width + x
+            scalars.SetTuple4(offset, background_color[0], background_color[1], background_color[2], background_color[3])
+
+    # Create an ImageMapper to create the background image
+    image_mapper = vtk.vtkImageMapper()
+    image_mapper.SetInputData(image_data)
+
+    # Create an ImageActor to display the background image
+    background_actor = vtk.vtkImageActor()
+    background_actor.SetMapper(image_mapper)
+
+    # Set background position
+    background_actor.SetDisplayPosition(position[0], position[1])
+
+    # Set interaction representation
+    button_rep = vtk.vtkButtonRepresentation2D()
+    button_rep.SetPlaceFactor(1.0)
+    button_rep.PlaceWidget(position, [size[0], size[1]])
+
+    # Create a ButtonWidget to manage button interaction
+    button = vtk.vtkButtonWidget()
+    button.SetInteractor(renderer.GetRenderWindow().GetInteractor())
+    button.SetRepresentation(button_rep)
+    button.AddObserver("StateChangedEvent", callback)
+
+    # Enable the button
+    button.On()
+
+    # Update background color to reflect button state
+    def update_button_state(obj, event):
+        state = button_rep.GetState()
+        if state == 0:  # Normal state
+            color = background_color
+        else:  # Highlight state
+            color = highlight_color
+
+        scalars = image_data.GetPointData().GetScalars()
+        for y in range(height):
+            for x in range(width):
+                offset = y * width + x
+                scalars.SetTuple4(offset, color[0], color[1], color[2], color[3])
+
+        image_data.Modified()
+        renderer.GetRenderWindow().Render()
+
+    button.AddObserver("StateChangedEvent", update_button_state)
+
+    # Add text and background to renderer
+    renderer.AddActor2D(text_actor)
+    renderer.AddActor2D(background_actor)
+
+    return button
+
+def read_and_display_stl(file_path1, file_path2):
+    # Read the first STL model
+    reader1 = vtk.vtkSTLReader()
+    reader1.SetFileName(file_path1)
+
+    mapper1 = vtk.vtkPolyDataMapper()
+    mapper1.SetInputConnection(reader1.GetOutputPort())
+
+    actor1 = vtk.vtkActor()
+    actor1.SetMapper(mapper1)
+
+    # Create label
+    text_property1 = vtk.vtkTextProperty()
+    text_property1.SetFontFamilyToArial()
+    text_property1.SetFontSize(16)
+    text_property1.BoldOn()
+    text_property1.ItalicOff()
+    text_property1.ShadowOn()
+
+    label1 = vtk.vtkTextActor()
+    label1.SetInput("Model 1")
+    label1.SetTextProperty(text_property1)
+    label1.SetDisplayPosition(50, 50)  # Set position
+
+    # Read the second STL model
+    reader2 = vtk.vtkSTLReader()
+    reader2.SetFileName(file_path2)
+
+    mapper2 = vtk.vtkPolyDataMapper()
+    mapper2.SetInputConnection(reader2.GetOutputPort())
+
+    actor2 = vtk.vtkActor()
+    actor2.SetMapper(mapper2)
+
+    # Create label
+    text_property2 = vtk.vtkTextProperty()
+    text_property2.SetFontFamilyToArial()
+    text_property2.SetFontSize(16)
+    text_property2.BoldOn()
+    text_property2.ItalicOff()
+    text_property2.ShadowOn()
+
+    label2 = vtk.vtkTextActor()
+    label2.SetInput("Model 2")
+    label2.SetTextProperty(text_property2)
+    label2.SetDisplayPosition(50, 100)  # Set position
+
+    renderer = vtk.vtkRenderer()
+    renderer.AddActor(actor1)
+    renderer.AddActor(actor2)
+    renderer.AddActor2D(label1)  # Add label to renderer
+    renderer.AddActor2D(label2)  # Add label to renderer
+
+    render_window = vtk.vtkRenderWindow()
+    render_window.AddRenderer(renderer)
+    render_window.SetSize(800, 600)  # Set window size to 800x600
+
+    interactor = vtk.vtkRenderWindowInteractor()
+    interactor.SetRenderWindow(render_window)
+
+    # Set custom interaction style
+    style = MyInteractorStyle()
+    style.actor1 = actor1
+    style.actor2 = actor2
+    style.label1 = label1
+    style.label2 = label2
+    interactor.SetInteractorStyle(style)
+
+    # Set default renderer
+    style.SetDefaultRenderer(renderer)
+
+    # Create button
+    callback = ButtonCallback(renderer, actor1)
+    button1 = create_button(renderer, callback.execute, (100, 550), (100, 30), "Move Actor 1")
+
+    interactor.Initialize()
+    interactor.Start()
+
+if __name__ == "__main__":
+    # Replace with your two STL file paths
+    stl_file_path1 = "tetrahedron.stl"
+    stl_file_path2 = "tetrahedron.stl"
+    read_and_display_stl(stl_file_path1, stl_file_path2)

+ 87 - 0
test_RGSGQF_filter.py

@@ -0,0 +1,87 @@
+import numpy as np
+
+# 定义权重函数
+def weight_function(zeta, k1=1.5, k2=5):
+    abs_zeta = np.abs(zeta)
+    psi = np.zeros_like(abs_zeta)
+    psi[abs_zeta <= k1] = 1
+    mask = (abs_zeta > k1) & (abs_zeta <= k2)
+    psi[mask] = (1 - ((abs_zeta[mask] - k1) ** 2 / (k2 - k1) ** 2)) ** 2
+    # 确保权重不为零
+    psi[psi < 1e-6] = 1e-6
+    # print(f"psi={psi}")
+    return psi
+
+# RESGQF滤波函数
+def resgqf_filter(x_hat, P, z, f, h, Q, R, k1=1.5, k2=5):
+    # 状态预测
+    x_pred = f(x_hat)
+    P_pred = P + Q
+
+    # 计算量测残差
+    z_hat = h(x_pred)
+    r = z - z_hat
+    # 标准化残差
+    R_inv_sqrt = np.linalg.inv(np.linalg.cholesky(R))
+    zeta = R_inv_sqrt @ r
+
+    # 计算权重矩阵
+    psi = weight_function(zeta)
+    Psi = np.diag(psi)
+
+    # 确保 Psi 是一个 2D 矩阵
+    if Psi.ndim == 1:
+        Psi = np.diag(Psi)
+
+    # # 添加正则化项以避免矩阵不可逆
+    # epsilon = 1e-6  # 小的正则化项
+    # Psi += epsilon * np.eye(Psi.shape[0])
+
+    # 更新量测噪声协方差
+    R_bar = R_inv_sqrt @ np.linalg.inv(Psi) @ R_inv_sqrt.T
+
+    # 计算跨协方差 P_xz 和自协方差 P_zz
+    P_xz = P_pred[0, 0] * z_hat[0]
+    P_zz = z_hat[0]**2 + R_bar[0, 0]
+
+
+    # 确保 P_zz 不为零以避免除以零错误
+    if P_zz < 1e-6:
+        P_zz = 1e-6
+
+    # 计算卡尔曼增益
+    K = P_xz / P_zz
+
+    # 状态更新
+    x_hat = x_pred + K * (z - z_hat[0])
+    P = P_pred - K**2 * P_xz
+
+    # 打印调试信息
+    print(f"量测: {z[0]}, 预测状态: {x_pred[0]}, 量测残差: {r[0]}, 标准化残差: {zeta[0]}, 权重: {psi[0]}, 量测噪声协方差: {R_bar[0, 0]}, 卡尔曼增益: {K}, 估计状态: {x_hat[0]}, 协方差: {P[0, 0]}")
+
+    return x_hat, P
+
+# 示例参数设置
+# 状态转移函数
+def f(x):
+    return x  # 简单的恒等转移
+
+# 量测函数
+def h(x):
+    return x  # 简单的恒等映射
+
+# 初始状态估计
+x_hat = np.array([0.1])
+# 初始协方差
+P = np.array([[1.0]])
+# 过程噪声协方差
+Q = np.array([[0.1]])
+# 量测噪声协方差
+R = np.array([[1.0]])
+
+# 模拟量测值,包含异常值
+measurements = [1.1, 1.2, 10.0, 1.3, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 10, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4]  # 10.0 为异常值
+
+# 运行滤波
+for z in measurements:
+    x_hat, P = resgqf_filter(x_hat, P, np.array([z]), f, h, Q, R)

+ 218 - 0
test_STEPS.py

@@ -0,0 +1,218 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.optimize import least_squares
+
+
+class STEPS:
+    """实现子信号技术参数估计算法(STEPS)"""
+
+    def __init__(self, fs, N0):
+        """
+        初始化STEPS估计器
+        :param fs: 采样频率
+        :param N0: 原始信号长度
+        """
+        self.fs = fs
+        self.N0 = N0
+        # 子信号长度选择为原始信号的2/3(论文推荐的最优比例)
+        self.N1 = int(2 * N0 / 3)
+        self.N2 = self.N1
+        # 子信号起始索引
+        self.n1 = 0
+        self.n2 = N0 - self.N2
+
+        # 确保子信号长度有效
+        if self.N1 <= 0 or self.N2 <= 0:
+            raise ValueError("子信号长度必须为正数,请增加原始信号长度")
+
+        # 预计算频率分辨率
+        self.freq_res = fs / N0
+
+    def extract_subsignals(self, x):
+        """
+        从原始信号中提取两个子信号
+        :param x: 原始信号
+        :return: 两个子信号x1和x2
+        """
+        if len(x) != self.N0:
+            raise ValueError(f"输入信号长度必须为{N0},实际为{len(x)}")
+
+        x1 = x[self.n1: self.n1 + self.N1]
+        x2 = x[self.n2: self.n2 + self.N2]
+        return x1, x2
+
+    def compute_dft(self, x):
+        """计算信号的DFT并返回幅度和相位"""
+        X = np.fft.fft(x)
+        amp = np.abs(X)
+        phase = np.angle(X)
+        return amp, phase
+
+    def find_peak_bin(self, amp):
+        """找到DFT幅度谱中的峰值位置"""
+        return np.argmax(amp)
+
+    def phase_correction(self, phase, k, N):
+        """
+        相位校正,消除线性相位影响
+        :param phase: 原始相位
+        :param k: 峰值所在的频点
+        :param N: 信号长度
+        :return: 校正后的相位
+        """
+        return phase - 2 * np.pi * k * (N - 1) / (2 * N)
+
+    def objective_function(self, delta1, k1, k2, c1, c2, phi1, phi2):
+        """用于求解非线性方程的目标函数"""
+        # 计算方程左边
+        left_numerator = c1 * np.tan(phi2) - c2 * np.tan(phi1)
+        left_denominator = c1 * c2 + np.tan(phi1) * np.tan(phi2)
+        left = left_numerator / left_denominator
+
+        # 计算方程右边
+        term = np.pi * (k1 + delta1) / self.N1 * (2 * (self.n2 - self.n1) + self.N2 - self.N1)
+        right = np.tan(term)
+
+        # 返回误差
+        return left - right
+
+    def estimate_frequency(self, x1, x2):
+        """
+        估计信号频率
+        :param x1, x2: 两个子信号
+        :return: 估计的频率
+        """
+        # 计算子信号的DFT
+        amp1, phase1 = self.compute_dft(x1)
+        amp2, phase2 = self.compute_dft(x2)
+
+        # 找到峰值位置
+        k1 = self.find_peak_bin(amp1)
+        k2 = self.find_peak_bin(amp2)
+
+        # 相位校正
+        phi1 = self.phase_correction(phase1[k1], k1, self.N1)
+        phi2 = self.phase_correction(phase2[k2], k2, self.N2)
+
+        # 计算c1和c2参数
+        c1 = np.sin(np.pi * k1 / self.N1) / np.sin(np.pi * (k1 + 1) / self.N1)
+        c2 = np.sin(np.pi * k2 / self.N2) / np.sin(np.pi * (k2 + 1) / self.N2)
+
+        # 求解非线性方程找到delta1
+        def func(delta):
+            return self.objective_function(delta, k1, k2, c1, c2, phi1, phi2)
+
+        # 使用最小二乘法求解
+        result = least_squares(func, x0=0.5, bounds=(0, 1))
+        delta1 = result.x[0]
+
+        # 计算频率
+        l0 = (self.N0 / self.N1) * (k1 + delta1)
+        f0 = (l0 / self.N0) * self.fs
+
+        return f0, k1, amp1[k1]
+
+    def estimate_amplitude(self, k1, amp1_peak):
+        """估计信号振幅"""
+        # 计算振幅校正因子
+        delta1 = (self.N1 / self.N0) * (self.fs / self.freq_res) - k1
+        correction = np.abs(np.sin(np.pi * delta1) / (self.N1 * np.sin(np.pi * (k1 + delta1) / self.N1)))
+        amplitude = amp1_peak * correction
+        return amplitude
+
+    def estimate_phase(self, f0, k1, amp1_peak):
+        """估计信号初始相位"""
+        delta1 = (self.N1 * f0 / self.fs) - k1
+        phase = np.angle(amp1_peak) - np.pi * delta1 * (self.N1 - 1) / self.N1
+        # 将相位归一化到[-π, π]范围
+        return (phase + np.pi) % (2 * np.pi) - np.pi
+
+    def estimate(self, x):
+        """
+        估计正弦信号的参数
+        :param x: 输入信号
+        :return: 频率、振幅和相位的估计值
+        """
+        # 提取子信号
+        x1, x2 = self.extract_subsignals(x)
+
+        # 估计频率
+        f0, k1, amp1_peak = self.estimate_frequency(x1, x2)
+
+        # 估计振幅
+        amp = self.estimate_amplitude(k1, amp1_peak)
+
+        # 估计相位
+        phase = self.estimate_phase(f0, k1, amp1_peak)
+
+        return f0, amp, phase
+
+
+# 演示如何使用STEPS类
+if __name__ == "__main__":
+    # 生成测试信号
+    fs = 1000  # 采样频率
+    N0 = 1024  # 信号长度
+    f_true = 75.3  # 真实频率
+    amp_true = 2.5  # 真实振幅
+    phase_true = np.pi / 3  # 真实相位
+    snr_db = 40  # 信噪比(dB)
+
+    # 生成时间向量
+    t = np.arange(N0) / fs
+
+    # 生成干净信号
+    x_clean = amp_true * np.sin(2 * np.pi * f_true * t + phase_true)
+
+    # 添加噪声
+    snr = 10 ** (snr_db / 10)
+    signal_power = np.sum(x_clean ** 2) / N0
+    noise_power = signal_power / snr
+    noise = np.sqrt(noise_power) * np.random.randn(N0)
+    x = x_clean + noise
+
+    # 创建STEPS估计器并进行参数估计
+    steps = STEPS(fs, N0)
+    f_est, amp_est, phase_est = steps.estimate(x)
+
+    # 计算误差
+    f_error = np.abs(f_est - f_true)
+    amp_error = np.abs(amp_est - amp_true)
+    phase_error = np.abs(phase_est - phase_true)
+
+    # 显示结果
+    print(f"真实频率: {f_true:.4f} Hz, 估计频率: {f_est:.4f} Hz, 误差: {f_error:.6f} Hz")
+    print(f"真实振幅: {amp_true:.4f}, 估计振幅: {amp_est:.4f}, 误差: {amp_error:.6f}")
+    print(f"真实相位: {phase_true:.4f} rad, 估计相位: {phase_est:.4f} rad, 误差: {phase_error:.6f} rad")
+
+    # 绘制信号和频谱
+    plt.figure(figsize=(12, 8))
+
+    # 设置中文字体支持
+    plt.rcParams["font.family"] = ["SimHei"]
+    plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示问题
+
+    # 绘制时域信号
+    plt.subplot(2, 1, 1)
+    plt.plot(t, x_clean, label='干净信号')
+    plt.plot(t, x, alpha=0.7, label='带噪声信号')
+    plt.xlabel('时间 (s)')
+    plt.ylabel('振幅')
+    plt.title('时域信号')
+    plt.legend()
+
+    # 绘制频谱
+    plt.subplot(2, 1, 2)
+    freq = np.fft.fftfreq(N0, 1 / fs)
+    x_fft = np.fft.fft(x)
+    plt.plot(freq[:N0 // 2], 20 * np.log10(np.abs(x_fft[:N0 // 2])), label='信号频谱')
+    plt.xlabel('频率 (Hz)')
+    plt.ylabel('幅度 (dB)')
+    plt.title('信号频谱')
+    plt.xlim(0, 150)  # 限制频率范围以便观察
+    plt.axvline(f_true, color='r', linestyle='--', label=f'真实频率: {f_true} Hz')
+    plt.axvline(f_est, color='g', linestyle='--', label=f'估计频率: {f_est:.2f} Hz')
+    plt.legend()
+
+    plt.tight_layout()
+    plt.show()

+ 40 - 10
test_pytorch.py

@@ -5,20 +5,24 @@ import numpy as np
 import matplotlib.pyplot as plt
 
 
-print(torch.__version__)  # pytorch版本
-print(torch.version.cuda)  # cuda版本
+print(f"torch.__version__:{torch.__version__}")  # pytorch版本
+print(f"torch.version.cuda:{torch.version.cuda}")  # cuda版本
 print(torch.cuda.is_available())  # 查看cuda是否可用
 
 #
 # 使用GPU or CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
+print(f"using device:{device}")
 
 # 生成一些随机的训练数据
 np.random.seed(42)
 x = np.random.rand(1000, 1)
-y = 2 * x + 1 + 0.5 * np.random.randn(1000, 1)
+y = 10 * x + 1 + 0.5 * np.random.randn(1000, 1)
+# y[900]=1000
+
+
+
 
 # 将数据转换为张量
 x_tensor = torch.from_numpy(x).float()
@@ -41,7 +45,9 @@ criterion = nn.MSELoss()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 
 # 训练模型
-num_epochs = 1000
+num_epochs = 6000
+dbug_epos=np.zeros(num_epochs)
+dbug_loss=np.zeros(num_epochs)
 for epoch in range(num_epochs):
     # 前向传播
     outputs = model(x_tensor)
@@ -52,6 +58,10 @@ for epoch in range(num_epochs):
     loss.backward()
     optimizer.step()
 
+    dbug_epos[epoch]=epoch
+    dbug_loss[epoch] = loss.item()
+
+
     if (epoch + 1) % 10 == 0:
         print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')
 
@@ -60,9 +70,29 @@ with torch.no_grad():
     predicted = model(x_tensor)
 
 # 绘制原始数据和预测结果
-plt.scatter(x, y, label='Original Data')
-plt.plot(x, predicted.numpy(), color='red', label='Predicted Line')
-plt.xlabel('x')
-plt.ylabel('y')
-plt.legend()
+
+# plt.scatter(x, y, label='Original Data')
+# plt.plot(x, predicted.numpy(), color='red', label='Predicted Line')
+# plt.xlabel('x')
+# plt.ylabel('y')
+# plt.legend()
+# plt.show()
+
+fig1, ax1 = plt.subplots()
+ax1.scatter(x, y, label='Original Data')
+ax1.plot(x, predicted.numpy(), color='red', label='Predicted Line')
+ax1.set_xlabel('x')
+ax1.set_ylabel('y')
+ax1.legend()
+
+
+fig2, ax2 = plt.subplots()
+ax2.scatter(dbug_epos, dbug_loss, label='loss')
+ax2.plot(dbug_epos, dbug_loss, color='red', label='loss')
+ax2.set_xlabel('dbug_epos')
+ax2.set_ylabel('dbug_loss')
+ax2.legend()
+
+
+
 plt.show()

+ 365 - 0
test_pytorch_MNIST.py

@@ -0,0 +1,365 @@
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+import numpy as np
+import tkinter as tk
+from PIL import Image, ImageDraw
+from tkinter import messagebox
+import cv2
+from torchsummary import summary
+from torchviz import make_dot
+import netron
+
+
+
+workmode = 1  #0:训练  1:加载模型
+
+
+# 数据预处理
+if  workmode==0:
+    transform = transforms.Compose([
+        # transforms.RandomRotation(10),  # 随机旋转 10 度
+        # transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # 随机平移
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+else:
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+
+# 加载训练集和测试集
+train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
+test_dataset = datasets.MNIST('data', train=False, transform=transform)
+
+train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
+test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
+
+test_loader_predicted=np.zeros(len(test_loader.dataset))
+
+
+def showfig(images,labels,totalnum,num):
+    # 选择要显示的图片数量
+    num_images = num
+
+    # 创建一个子图布局
+    fig, axes = plt.subplots(1, num_images, figsize=(15, 3))
+
+    # 遍历数据集并显示图片和标签
+    for i in range(num_images):
+        # image, label = train_dataset[i]
+        random_numbers = np.random.choice(totalnum+1, num_images, replace=False)
+        # image = train_dataset.train_data[random_numbers[i]]
+        # label = train_dataset.train_labels[random_numbers[i]]
+        image = images[random_numbers[i]]
+        label = labels[random_numbers[i]]
+        # 将张量转换为numpy数组并调整维度
+        image = image.squeeze().numpy()
+        # 显示图片
+        axes[i].imshow(image, cmap='gray')
+        # 设置标题为标签
+        axes[i].set_title(f'idx-{random_numbers[i]}-Label: {label}')
+        axes[i].axis('off')
+    # # # 显示图形
+    # plt.show()
+
+
+# while 1:
+#     pass
+showfig(train_dataset.data,train_dataset.targets,60000,4)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# # 定义神经网络模型
+# class Net(nn.Module):
+#     def __init__(self):
+#         super(Net, self).__init__()
+#         self.fc1 = nn.Linear(784, 128)
+#         self.fc2 = nn.Linear(128, 64)
+#         self.fc3 = nn.Linear(64, 10)
+#
+#     def forward(self, x):
+#         x = x.view(-1, 784)
+#         x = torch.relu(self.fc1(x))
+#         x = torch.relu(self.fc2(x))
+#         x = self.fc3(x)
+#         return x
+# model = Net()
+# criterion = nn.CrossEntropyLoss()
+# optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+
+
+# # 定义修改后的 AlexNet 模型,适应 MNIST 数据集
+# class AlexNet(nn.Module):
+#     def __init__(self, num_classes=10):
+#         super(AlexNet, self).__init__()
+#         self.features = nn.Sequential(
+#             nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 修改卷积核大小和通道数
+#             nn.ReLU(inplace=True),
+#             nn.MaxPool2d(kernel_size=2, stride=2),  # 修改池化核大小和步长
+#             nn.Conv2d(32, 64, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.MaxPool2d(kernel_size=2, stride=2),
+#             nn.Conv2d(64, 128, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.MaxPool2d(kernel_size=2, stride=2),
+#         )
+#         self.classifier = nn.Sequential(
+#             nn.Dropout(),
+#             nn.Linear(128 * 3 * 3, 128),  # 修改全连接层输入维度
+#             nn.ReLU(inplace=True),
+#             nn.Dropout(),
+#             nn.Linear(128, 128),
+#             nn.ReLU(inplace=True),
+#             nn.Linear(128, num_classes),
+#         )
+#
+#     def forward(self, x):
+#         x = self.features(x)
+#         x = x.view(x.size(0), 128 * 3 * 3)  # 修改展平后的维度
+#         x = self.classifier(x)
+#         return x
+#
+
+#
+# # 初始化模型、损失函数和优化器
+# model = AlexNet().to(device)
+
+
+# LeNet-5模型定义
+class LeNet5(nn.Module):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
+        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
+        self.fc1 = nn.Linear(16 * 4 * 4, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = torch.relu(self.conv1(x))
+        x = nn.MaxPool2d(2)(x)
+        x = torch.relu(self.conv2(x))
+        x = nn.MaxPool2d(2)(x)
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.fc1(x))
+        x = torch.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+# 实例化模型
+model = LeNet5()
+
+
+
+
+
+
+
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+
+# # 打印模型结构
+# summary(model, input_size=(1, 28, 28))
+
+# # 创建一个随机输入张量
+# x = torch.randn(1, 1, 28, 28).to(device)
+# # 前向传播
+# y = model(x).to(device)
+# # 使用 torchviz 生成计算图
+# dot = make_dot(y, params=dict(model.named_parameters()))
+# # 保存计算图为图像文件(这里保存为 PNG 格式)
+# dot.render('alexnet_model', format='png', cleanup=True, view=True)
+
+# 训练过程
+def train(model, train_loader, optimizer, criterion, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % 100 == 0:
+            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+
+# 验证过程
+
+
+def test(model, test_loader):
+    model.eval()
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for batch_idx ,(data, target) in enumerate(test_loader):
+            output = model(data)
+            _, predicted = torch.max(output.data, 1)
+            test_loader_predicted[batch_idx*64+0:batch_idx*64+64]=predicted.numpy()
+            total += target.size(0)
+            correct += (predicted == target).sum().item()
+
+    accuracy = 100 * correct / total
+    print('Test Accuracy: {:.2f}%'.format(accuracy))
+
+# # 训练和验证模型
+
+if workmode==0:  #0:训练
+    for epoch in range(10):
+        train(model, train_loader, optimizer, criterion, epoch)
+        test(model, test_loader)
+
+    torch.save(model, 'model.pth')
+    print(f'save model as model.pth')
+else:            #1:加载模型
+    print(f'load model.pth')
+    try:
+        model = torch.load('model.pth', weights_only=False)
+        model.eval()
+    except Exception as e:
+        print(f"加载模型时出现错误: {e}")
+    test(model, test_loader)
+
+    # netron.start('model.pth')  # 输出网络结构图
+
+showfig(test_loader.dataset.data, test_loader_predicted, 10000, 4)
+plt.show()
+
+
+def save_drawing():
+    global drawing_points
+    # 创建一个空白图像
+    image = Image.new("RGB", (canvas.winfo_width(), canvas.winfo_height()), "black")
+    draw = ImageDraw.Draw(image)
+
+    # 绘制线条
+    for i in range(1, len(drawing_points)):
+        x1,y1=drawing_points[i - 1]
+        x2, y2 = drawing_points[i]
+        if (x1 is not None) and (x2 is not None) and (y1 is not None) and (y2 is not None):
+            draw.line((x1, y1, x2, y2), fill="white", width=20)
+
+    image = image.convert('L')
+    image1=image.resize((28,28))
+    # # 4. 转换为 numpy 数组
+    # image_array = np.array(image1)
+    #
+    # # 5. 二值化
+    # _, binary_image = cv2.threshold(image_array, 127, 255, cv2.THRESH_BINARY)
+    #
+    # # 6. 居中处理
+    # rows, cols = binary_image.shape
+    # M = cv2.moments(binary_image)
+    # if M["m00"] != 0:
+    #     cX = int(M["m10"] / M["m00"])
+    #     cY = int(M["m01"] / M["m00"])
+    # else:
+    #     cX, cY = 0, 0
+    # shift_x = cols / 2 - cX
+    # shift_y = rows / 2 - cY
+    # M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
+    # centered_image = cv2.warpAffine(binary_image, M, (cols, rows))
+    #
+    # # 7. 归一化
+    # normalized_image = centered_image / 255.0
+    #
+    # # 8. 调整维度以适应模型输入
+    # final_image = normalized_image.reshape(28, 28)
+    # image1 = Image.fromarray(final_image)
+
+
+
+    # # 转换为numpy数组
+    # img_array = np.array(image1)
+    # # 中值滤波
+    # filtered_img = cv2.medianBlur(img_array, 3)
+    # # 转换回Image对象(如果需要的话)
+    # image1 = Image.fromarray(filtered_img)
+
+    tensor_image = transform(image1) #torch.Size([3, 28, 28])
+
+    # gray_tensor = torch.mean(tensor_image, dim=0, keepdim=True)
+    # pool = torch.nn.MaxPool2d(kernel_size=10, stride=10)
+    # pooled_image = pool(gray_tensor.unsqueeze(0)).squeeze(0)
+
+    # pooled_image = gray_tensor
+    pooled_image = tensor_image.unsqueeze(0)
+
+
+    # print(f'tensor_image :{gray_tensor.shape} -pooled_image:{pooled_image.shape}')
+
+    # simage=pooled_image.view(28,28)
+    # simage = (simage - simage.min()) / (simage.max() - simage.min())
+    # np_array = (simage.numpy() * 255).astype('uint8')
+    # image_f = Image.fromarray(np_array)
+    # image_f.show()
+
+
+
+    with torch.no_grad():
+        output = torch.softmax(model(pooled_image),1)
+        # print(f'output.data={output}')
+        v, predicted = torch.max(output, 1)
+
+    print(f'预测数字={predicted.numpy()[0]},概率:{(v*100).numpy()[0]:.2f}%')
+    messagebox.showinfo('识别结果',f'predicted={predicted.numpy()[0]}')
+    drawing_points=[]
+    canvas.delete("all")
+
+    # 保存图像
+    image.save("drawing.png")
+    image1.save("drawing28x28.png")
+    # print("绘画已保存为 drawing.png")
+
+last_x,last_y=None,None
+# last_y=[]
+def on_mouse_move(event):
+    global last_x,last_y
+
+    # drawing_points.append((event.x, event.y))
+    drawing_points.append((last_x, last_y))
+    if (last_x is not None) and  (last_y is not None) :
+        canvas.create_line(last_x, last_y, event.x, event.y, fill="white", width=20, smooth=True, splinesteps=10)
+        # canvas.create_line(last_x, last_y , event.x, event.y, fill="white", width=20)
+
+    last_x, last_y = event.x, event.y
+
+
+def on_mouse_release(event):
+    global last_x, last_y
+    last_x, last_y = None, None
+    # print("on_mouse_release")
+    pass
+
+
+
+
+
+root = tk.Tk()
+
+canvas = tk.Canvas(root, width=280*2, height=280*2, bg="black")
+canvas.pack()
+
+# canvas_show = tk.Canvas(root, width=280, height=280, bg="black")
+# canvas_show.pack()
+
+button = tk.Button(root, text="识别", command=save_drawing)
+button.pack()
+
+drawing_points = []
+canvas.bind("<B1-Motion>", on_mouse_move)
+canvas.bind("<ButtonRelease-1>", on_mouse_release)
+
+root.mainloop()

+ 2 - 1
test_pytorch_cuda.py

@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
 
 # 检查是否有可用的GPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"using device:{device}")
 
 # 定义数据预处理步骤
 transform = transforms.Compose([
@@ -49,7 +50,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.01)
 # 训练模型
 def train_model():
     model.train()
-    for epoch in range(20):
+    for epoch in range(5):  # should be 20
         running_loss = 0.0
         for batch_idx, (data, target) in enumerate(train_loader):
             data, target = data.to(device), target.to(device)

+ 112 - 0
test_pytorch_lstm.py

@@ -0,0 +1,112 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import matplotlib.pyplot as plt
+
+# 设置随机种子以确保结果可重复
+np.random.seed(0)
+torch.manual_seed(0)
+
+# 1. 生成随机一位序列
+sequence_length = 100
+# random_sequence = np.random.randint(0, 200, size=sequence_length)
+random_sequence = np.linspace(0, 200, sequence_length)
+# 将序列保存到文件
+np.savetxt('random_sequence.csv', random_sequence, delimiter=',')
+
+# 2. 定义 LSTM 模型
+class LSTMModel(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, output_size):
+        super(LSTMModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
+        self.fc = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
+        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
+        out, _ = self.lstm(x, (h0, c0))
+        out = self.fc(out[:, -1, :])
+        return out
+
+# 3. 准备数据
+def create_sequences(data, seq_length):
+    xs, ys = [], []
+    for i in range(len(data) - seq_length):
+        x = data[i:i+seq_length]
+        y = data[i+seq_length]
+        xs.append(x)
+        ys.append(y)
+    return np.array(xs), np.array(ys)
+
+# 参数设置
+input_size = 1
+hidden_size = 32
+num_layers = 2
+output_size = 1
+seq_length = 10
+
+# 创建序列
+X, y = create_sequences(random_sequence, seq_length)
+
+# 转换为 PyTorch 张量
+X = torch.tensor(X, dtype=torch.float32).unsqueeze(2)
+y = torch.tensor(y, dtype=torch.float32).unsqueeze(1)
+
+# 4. 训练 LSTM 模型
+model = LSTMModel(input_size, hidden_size, num_layers, output_size)
+criterion = nn.MSELoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+num_epochs = 100*100
+for epoch in range(num_epochs):
+    model.train()
+    outputs = model(X)
+    loss = criterion(outputs, y)
+
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+    if (epoch+1) % 10 == 0:
+        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
+
+# 5. 生成预测结果
+model.eval()
+with torch.no_grad():
+    # 初始化 test_input 为前 seq_length 个序列
+    test_input = X[0].unsqueeze(0)  # 形状 (1, 5, 1)
+    predicted = []
+    for _ in range(sequence_length - seq_length):
+        output = model(test_input)
+        predicted.append(output.item())
+        # 更新 test_input: 移除第一个时间步并添加预测的输出
+        test_input = torch.cat((test_input[:, 1:], output.view(1, 1, 1)), dim=1)
+
+    predicted = np.array(predicted)
+
+np.savetxt('random_sequence_predicted.csv', predicted, delimiter=',')
+
+# 6. 使用 matplotlib 绘制对比图
+plt.figure(figsize=(12, 6))
+
+# 绘制原始序列
+plt.subplot(1, 2, 1)
+plt.scatter(range(sequence_length), random_sequence, label='Original Sequence', color='blue', s=50)
+plt.title('原始序列')
+plt.xlabel('时间步')
+plt.ylabel('值')
+plt.legend()
+
+# 绘制预测序列
+plt.subplot(1, 2, 2)
+plt.scatter(range(seq_length, sequence_length), predicted, label='Predicted Sequence', color='red', s=50)
+plt.title('预测序列')
+plt.xlabel('时间步')
+plt.ylabel('值')
+plt.legend()
+
+plt.tight_layout()
+plt.show()

+ 43 - 0
test_supervision.py

@@ -0,0 +1,43 @@
+import supervision as sv
+import cv2
+import os
+from inference import get_model
+from ultralytics import YOLO
+
+
+print(sv.__version__)
+
+HOME = os.getcwd()
+# print(HOME)
+IMAGE_PATH = f"{HOME}//images//dog-3.jpeg"
+image = cv2.imread(IMAGE_PATH)
+
+# # 修正 cv2.imshow() 函数调用,添加窗口名称
+# cv2.imshow('Display Image', image)
+# # 等待按键事件,防止窗口立即关闭
+# cv2.waitKey(0)
+# # 关闭所有 OpenCV 窗口
+# cv2.destroyAllWindows()
+
+# model = get_model(model_id="yolov8s-640")
+# result = model.infer(image)[0]
+# detections = sv.Detections.from_inference(result)
+
+
+
+model = YOLO("yolov8s.pt")
+result = model(image, verbose=False)[0]
+detections = sv.Detections.from_ultralytics(result)
+
+# print(detections)
+box_annotator = sv.BoxAnnotator()
+label_annotator = sv.LabelAnnotator()
+
+annotated_image = image.copy()
+annotated_image = box_annotator.annotate(annotated_image, detections=detections)
+annotated_image = label_annotator.annotate(annotated_image, detections=detections)
+
+sv.plot_image(image=annotated_image, size=(8, 8))
+
+print('end of file')
+

+ 20 - 0
test_supervision_videos.py

@@ -0,0 +1,20 @@
+from supervision.assets import download_assets, VideoAssets
+import supervision as sv
+import cv2
+import os
+from inference import get_model
+from ultralytics import YOLO
+
+print(sv.__version__)
+HOME = os.getcwd()
+
+download_assets(VideoAssets.VEHICLES)
+VIDEO_PATH = VideoAssets.VEHICLES.value
+
+sv.VideoInfo.from_video_path(video_path=VIDEO_PATH)
+frame_generator = sv.get_video_frames_generator(source_path=VIDEO_PATH)
+frame = next(iter(frame_generator))
+sv.plot_image(image=frame, size=(4, 4))
+
+RESULT_VIDEO_PATH = f"{HOME}/videos/vehicle-counting-result.mp4"
+