2 Commits e1378e7a54 ... cb695c6304

Author SHA1 Message Date
  alex cb695c6304 Add new test code 1 month ago
  alex ad84308fc4 Add test_ai test code 6 months ago
20 changed files with 1635 additions and 11 deletions
  1. .idea/MarsCodeWorkspaceAppSettings.xml  +7 -0
  2. .idea/PythonProject.iml  +6 -0
  3. .idea/encodings.xml  +4 -0
  4. 1.TXT  +3 -0
  5. MZ/.idea/MZ.iml  +10 -0
  6. example.py  +77 -0
  7. showhtml.py  +35 -0
  8. testUI.py  +135 -0
  9. test_3D.py  +231 -0
  10. test_RGSGQF_filter.py  +87 -0
  11. test_STEPS.py  +218 -0
  12. test_ai.py  +121 -0
  13. test_gym.py  +24 -0
  14. test_pytorch.py  +40 -10
  15. test_pytorch_MNIST.py  +365 -0
  16. test_pytorch_cuda.py  +2 -1
  17. test_pytorch_lstm.py  +112 -0
  18. test_supervision.py  +43 -0
  19. test_supervision_videos.py  +20 -0
  20. test_web_auto.py  +95 -0

+ 7 - 0
.idea/MarsCodeWorkspaceAppSettings.xml

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="com.codeverse.userSettings.MarscodeWorkspaceAppSettingsState">
+    <option name="ckgOperationStatus" value="SUCCESS" />
+    <option name="progress" value="0.77272725" />
+  </component>
+</project>

+ 6 - 0
.idea/PythonProject.iml

@@ -7,4 +7,10 @@
     <orderEntry type="jdk" jdkName="Python 3.11 (PythonProject)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
+  <component name="PackageRequirementsSettings">
+    <option name="requirementsPath" value="" />
+  </component>
+  <component name="TestRunnerService">
+    <option name="PROJECT_TEST_RUNNER" value="Unittests" />
+  </component>
 </module>

+ 4 - 0
.idea/encodings.xml

@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Encoding" defaultCharsetForPropertiesFiles="UTF-8" />
+</project>

+ 3 - 0
1.TXT

@@ -0,0 +1,3 @@
+I will help you extend
+testUI.py
+ so that clicking the first sub-item under the first menu reads the 1.txt file and displays its contents in the main window. Below is a complete code example:
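(The code example itself is not included in 1.TXT. The following is a minimal sketch, an assumption using plain tkinter rather than the ttkbootstrap setup in testUI.py, of what such a handler could look like; the function name show_txt is hypothetical.)

import tkinter as tk

root = tk.Tk()
text_area = tk.Text(root, wrap="word", font=("Arial", 12))
text_area.pack(fill="both", expand=True)

def show_txt():
    # Hypothetical handler: read 1.txt and show its contents in the text area
    try:
        with open("1.txt", "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        content = "1.txt not found."
    text_area.delete("1.0", tk.END)
    text_area.insert(tk.END, content)

menu_bar = tk.Menu(root)
first_menu = tk.Menu(menu_bar, tearoff=0)
first_menu.add_command(label="Item 1", command=show_txt)
menu_bar.add_cascade(label="Menu 1", menu=first_menu)
root.config(menu=menu_bar)
root.mainloop()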

+ 10 - 0
MZ/.idea/MZ.iml

@@ -0,0 +1,10 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/venv" />
+    </content>
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>

+ 77 - 0
example.py

@@ -0,0 +1,77 @@
+import tkinter as tk
+from tkhtmlview import HTMLScrolledText
+
+root = tk.Tk()
+root.geometry("780x640")
+html_label = HTMLScrolledText(
+    root,
+    html="""
+      <h1 style="color: red; text-align: center"> Hello World </h1>
+      <h2> Smaller </h2>
+      <h3> Little Smaller </h3>
+      <h4> Little Little Smaller </h4>
+      <h5> Little Little Little Smaller </h5>
+      <h6> Little Little Little Little Smaller </h6>
+      <b>This is a Bold text</b><br/>
+      <strong>This is an Important text</strong><br/>
+      <i>This is an Italic text</i><br/>
+      <em>This is an Emphasized text</em><br/>
+      <em>This is a <strong>Strong Emphasized</strong>   text   </em>.<br/>
+      <img src="https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"></img>
+      <ul>
+      <li> One </li>
+        <ul>
+            <li> One.Zero </li>
+            <li> One.One </li>
+        </ul>
+      <li> Two </li>
+      <li> Three </li>
+      <li> Four </li>
+      </ul>
+
+      <h3> Paragraph </h3>
+      <p>Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed ut quam sapien. Maecenas porta tempus mauris sed ullamcorper. Nulla facilisi. Nulla facilisi. Mauris tristique ipsum et efficitur lobortis. Sed pharetra ipsum non lacinia dignissim. Ut condimentum vulputate sem eget scelerisque. Curabitur ornare augue enim, sed volutpat enim finibus id. </p>
+
+      <h3>Table</h3>
+      <table>
+          <tr>
+            <th style="background-color: silver">T Header 1</th>
+            <th style="background-color: silver">T Header 2</th>
+          </tr>
+          <tr>
+            <td>ABC</td>
+            <td><em>123</em></td>
+          </tr>
+          <tr>
+            <td>DEF</td>
+            <td><i>456</i></td>
+          </tr>
+      </table>
+
+      <h3> Preformatted text: spaces and newlines preserved </h3>
+      <pre style="font-family: Consolas; font-size: 80%">
+<b>Tags      Attributes          Notes</b>
+a         style, <span style="color: RoyalBlue">href         </span>-
+img       <u>src</u>, width, height  <i>experimental</i>
+ol        <span style="color: Tomato">style, type         </span>-
+ul        <i>style               </i>bullet glyphs only
+div       style               -
+      </pre>
+
+      <h3> Code: spaces and newlines ignored </h3>
+      <code style="font-family: Consolas; font-size: 80%">
+<b>Tags      Attributes          Notes</b><br/>
+a         style, <span style="color: RoyalBlue">href         </span>-
+img       <u>src</u>, width, height  <i>experimental</i>
+ol        <span style="color: Tomato">style, type         </span>-
+ul        <i>style               </i>bullet glyphs only
+div       style               -
+      </code>
+
+      <p>Inline <code>code blocks</code> displayed without newlines</p>
+
+    """,
+)
+html_label.pack(fill="both", expand=True)
+html_label.fit_height()
+root.mainloop()

+ 35 - 0
showhtml.py

@@ -0,0 +1,35 @@
+import os
+import tkinter as tk
+from tkinterweb import HtmlFrame
+
+
+def main():
+    root = tk.Tk()
+    root.title("Base path example")
+    root.geometry("800x600")
+
+    # Create the HTML browser widget
+    browser = HtmlFrame(root)
+    browser.pack(fill="both", expand=True)
+
+    # Path to the HTML file (replace with your actual path)
+    html_file = "data\\2.htm"
+
+    # Directory that contains the HTML file
+    html_dir = os.path.dirname(os.path.abspath(html_file))
+
+    # Read the HTML content
+    with open(html_file, "r", encoding="utf-8") as file:
+        html_content = file.read()
+
+    # Base path (using the file:// protocol)
+    # base_uri = f"file://{html_dir}/"
+
+    # Load the HTML content and set the base path
+    browser.add_html(html_content)
+
+    root.mainloop()
+
+
+if __name__ == "__main__":
+    main()
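Note that html_dir is computed but never used, so relative resources referenced by data\2.htm (images, stylesheets) may not resolve. A hedged sketch of passing the base path through, assuming tkinterweb's load_html(html, base_url=...) as referenced in the commented-out lines of testUI.py below:

    # Sketch (assumption): give the browser a base URL so relative links resolve
    base_uri = f"file://{html_dir}/"
    browser.load_html(html_content, base_url=base_uri)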

+ 135 - 0
testUI.py

@@ -0,0 +1,135 @@
+import ttkbootstrap as ttk
+import os
+import tkinter as tk
+from tkhtmlview import HTMLLabel
+# from tkinterweb import HtmlFrame
+
+
+def on_menu_item_click(label):
+    if label == "Item 1":
+        display_html_file("data/1.html")
+    elif label == "Item 2":
+        display_html_file("data/2.html")
+    elif label == "Sub Item 1":
+        display_text_and_entries()
+    elif label == "about":
+        display_html_file("data/about.html")
+    else:
+        print(f"Clicked on {label}")
+
+def display_html_file(filepath):
+    # Clear the current window contents, but keep the menu bar
+    for widget in window.winfo_children():
+        if not isinstance(widget, ttk.Menu):
+            widget.destroy()
+
+    if os.path.exists(filepath):
+        with open(filepath, 'r', encoding='utf-8') as file:
+            html_content = file.read()
+
+        # Create an HTMLLabel to display the HTML content
+        global html_label
+        html_label = HTMLLabel(window, html=html_content,wrap=ttk.CHAR)
+        # tkinterweb
+        # html_label = HtmlFrame(window)
+        #
+        # html_label.load_file("C:\\Users\\wang\\PycharmProjects\\PythonProject\\data\\1.html")
+
+        # html_label.add_html(html_content)
+        #html_label.load_html(html_content,base_url="C:\\Users\\wang\\PycharmProjects\\PythonProject\\data")
+
+
+        html_label.pack(side=ttk.TOP, fill=ttk.BOTH, expand=True)
+        html_label.fit_height()
+    else:
+        # Create a text area to display the error message
+        global text_area
+        text_area = ttk.Text(window, wrap=ttk.WORD, font=('Arial', 12))
+        text_area.pack(side=ttk.TOP, fill=ttk.BOTH, expand=True)
+        text_area.insert(ttk.END, f"File {filepath} not found.")
+
+def display_text_and_entries():
+    # Clear the current window contents, but keep the menu bar
+    for widget in window.winfo_children():
+        if not isinstance(widget, ttk.Menu):
+            widget.destroy()
+
+    global entry_list
+    entry_list = []
+    global label_list
+    label_list = []
+
+    # Create 10 rows of widgets
+    for i in range(1, 11):
+        frame = ttk.Frame(window)
+        frame.pack(pady=5, fill=ttk.X)
+
+        title_label = ttk.Label(frame, text=f"Title {i}", font=('Arial', 12))
+        title_label.pack(side=ttk.LEFT, padx=5)
+
+        entry = ttk.Entry(frame, font=('Arial', 12), width=20)
+        entry.pack(side=ttk.LEFT, padx=5)
+
+        entry_list.append(entry)
+
+        # Add a new Label to display content
+        label = ttk.Label(frame, text=f"text {i}", font=('Arial', 12))
+        label.pack(side=ttk.LEFT, padx=5)
+
+        label_list.append(label)
+
+    # Create the button
+    fill_button = ttk.Button(window, text="Fill All with Hello", command=fill_entries_with_hello)
+    fill_button.pack(pady=10)
+
+def fill_entries_with_hello():
+    for entry in entry_list:
+        entry.delete(0, ttk.END)
+        entry.insert(0, "Hello")
+
+def display_text_area():
+    # Clear the current window contents, but keep the menu bar
+    for widget in window.winfo_children():
+        if not isinstance(widget, ttk.Menu):
+            widget.destroy()
+
+    global text_area
+    text_area = ttk.Text(window, wrap=ttk.WORD, font=('Arial', 12))
+    text_area.pack(side=ttk.TOP, fill=ttk.BOTH, expand=True)
+
+# Create the window and set the 'flatly' theme
+window = ttk.Window(themename='flatly')
+
+# Set the window size to 2000x1200
+window.geometry("2000x1200")
+
+# Show the text_area initially
+display_text_area()
+
+# Create the menu bar
+menu_bar = ttk.Menu(window)
+
+# Create the first menu and its items
+first_menu = ttk.Menu(menu_bar, tearoff=0)
+for i in range(1, 11):
+    menu_label = f"Item {i}"
+    first_menu.add_command(label=menu_label, command=lambda label=menu_label: on_menu_item_click(label))
+menu_bar.add_cascade(label="Menu 1", menu=first_menu)
+
+# Create the second menu and its items
+second_menu = ttk.Menu(menu_bar, tearoff=0)
+second_menu.add_command(label="Sub Item 1", command=lambda: on_menu_item_click("Sub Item 1"))
+second_menu.add_command(label="Sub Item 2", command=lambda: on_menu_item_click("Sub Item 2"))
+menu_bar.add_cascade(label="Menu 2", menu=second_menu)
+
+# Create the third menu and its items
+third_menu = ttk.Menu(menu_bar, tearoff=0)
+third_menu.add_command(label="About this software", command=lambda: on_menu_item_click("about"))
+#third_menu.add_command(label="Sub Item 2", command=lambda: on_menu_item_click("Sub Item 2"))
+menu_bar.add_cascade(label="About", menu=third_menu)
+
+# Attach the menu bar to the window
+window.config(menu=menu_bar)
+
+# Run the window main loop
+window.mainloop()
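A side note on the first menu: each item is bound with command=lambda label=menu_label: on_menu_item_click(label). The default argument is what freezes the current loop value; without it, Python's late-binding closures would make every item report the last label. A tiny, self-contained illustration (not part of the commit):

late = [lambda: print(i) for i in range(3)]
bound = [lambda i=i: print(i) for i in range(3)]
late[0]()   # prints 2 - every closure sees the final value of i
bound[0]()  # prints 0 - the default argument captured it per iteration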

+ 231 - 0
test_3D.py

@@ -0,0 +1,231 @@
+import vtk
+
+class ButtonCallback(object):
+    def __init__(self, renderer, actor1):
+        self.renderer = renderer
+        self.actor1 = actor1
+
+    def execute(self, obj, event):
+        print("Button was clicked!")
+        # Move actor1
+        if self.actor1:
+            current_position = self.actor1.GetPosition()
+            self.actor1.SetPosition(current_position[0] + 0.1, current_position[1], current_position[2])
+            self.renderer.GetRenderWindow().Render()
+
+class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
+    def __init__(self, parent=None):
+        super().__init__()
+        self.AddObserver("LeftButtonPressEvent", self.left_button_press)
+        self.actor1 = None
+        self.actor2 = None
+        self.label1 = None
+        self.label2 = None
+
+    def left_button_press(self, obj, event):
+        try:
+            click_pos = self.GetInteractor().GetEventPosition()
+            print(f'left_button_press({click_pos[0]}, {click_pos[1]})')
+
+            # Get the default renderer
+            renderer = self.GetDefaultRenderer()
+            if renderer is None:
+                print("Error: Default renderer is None")
+                return
+
+            # Create a picker and perform picking
+            picker = vtk.vtkPropPicker()
+            picker.Pick(click_pos[0], click_pos[1], 0, renderer)
+
+            # Get the picked actor
+            picked_actor = picker.GetActor()
+            if picked_actor is None:
+                print("No actor was picked")
+            elif picked_actor == self.actor1:
+                print("Actor 1 was picked")
+                # Define the movement of actor2
+                self.actor2.SetPosition(self.actor2.GetPosition()[0] + 0.1,
+                                       self.actor2.GetPosition()[1],
+                                       self.actor2.GetPosition()[2])
+            elif picked_actor == self.actor2:
+                print("Actor 2 was picked")
+                # Define the movement of actor1
+                self.actor1.SetPosition(self.actor1.GetPosition()[0] + 0.1,
+                                       self.actor1.GetPosition()[1],
+                                       self.actor1.GetPosition()[2])
+            else:
+                print("Some other actor was picked")
+
+            self.GetInteractor().Render()
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+def create_button(renderer, callback, position, size, label):
+    # Create a TextActor to display the button text
+    text_actor = vtk.vtkTextActor()
+    text_property = vtk.vtkTextProperty()
+    text_property.SetFontFamilyToArial()
+    text_property.SetFontSize(12)
+    text_property.BoldOn()
+    text_property.ItalicOff()
+    text_property.ShadowOn()
+    text_property.SetJustificationToCentered()
+    text_property.SetColor(0, 0, 0)  # Text color
+    text_actor.SetInput(label)
+    text_actor.SetTextProperty(text_property)
+
+    # Calculate text position, center it
+    text_actor.SetDisplayPosition(position[0] + size[0] // 2, position[1] + size[1] // 2 - 10)
+
+    # Create an ImageData to create the background image
+    width, height = size
+    image_data = vtk.vtkImageData()
+    image_data.SetDimensions(width, height, 1)
+    image_data.SetSpacing(1, 1, 1)
+    image_data.SetOrigin(0, 0, 0)
+    image_data.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 4)
+
+    colors = vtk.vtkNamedColors()
+    background_color = colors.GetColor4ub("LightGray")
+    highlight_color = colors.GetColor4ub("LightBlue")
+
+    # Set the default background color
+    scalars = image_data.GetPointData().GetScalars()
+    for y in range(height):
+        for x in range(width):
+            offset = y * width + x
+            scalars.SetTuple4(offset, background_color[0], background_color[1], background_color[2], background_color[3])
+
+    # Create an ImageMapper to create the background image
+    image_mapper = vtk.vtkImageMapper()
+    image_mapper.SetInputData(image_data)
+
+    # Create an ImageActor to display the background image
+    background_actor = vtk.vtkImageActor()
+    background_actor.SetMapper(image_mapper)
+
+    # Set background position
+    background_actor.SetDisplayPosition(position[0], position[1])
+
+    # Set interaction representation
+    button_rep = vtk.vtkButtonRepresentation2D()
+    button_rep.SetPlaceFactor(1.0)
+    button_rep.PlaceWidget(position, [size[0], size[1]])
+
+    # Create a ButtonWidget to manage button interaction
+    button = vtk.vtkButtonWidget()
+    button.SetInteractor(renderer.GetRenderWindow().GetInteractor())
+    button.SetRepresentation(button_rep)
+    button.AddObserver("StateChangedEvent", callback)
+
+    # Enable the button
+    button.On()
+
+    # Update background color to reflect button state
+    def update_button_state(obj, event):
+        state = button_rep.GetState()
+        if state == 0:  # Normal state
+            color = background_color
+        else:  # Highlight state
+            color = highlight_color
+
+        scalars = image_data.GetPointData().GetScalars()
+        for y in range(height):
+            for x in range(width):
+                offset = y * width + x
+                scalars.SetTuple4(offset, color[0], color[1], color[2], color[3])
+
+        image_data.Modified()
+        renderer.GetRenderWindow().Render()
+
+    button.AddObserver("StateChangedEvent", update_button_state)
+
+    # Add text and background to renderer
+    renderer.AddActor2D(text_actor)
+    renderer.AddActor2D(background_actor)
+
+    return button
+
+def read_and_display_stl(file_path1, file_path2):
+    # Read the first STL model
+    reader1 = vtk.vtkSTLReader()
+    reader1.SetFileName(file_path1)
+
+    mapper1 = vtk.vtkPolyDataMapper()
+    mapper1.SetInputConnection(reader1.GetOutputPort())
+
+    actor1 = vtk.vtkActor()
+    actor1.SetMapper(mapper1)
+
+    # Create label
+    text_property1 = vtk.vtkTextProperty()
+    text_property1.SetFontFamilyToArial()
+    text_property1.SetFontSize(16)
+    text_property1.BoldOn()
+    text_property1.ItalicOff()
+    text_property1.ShadowOn()
+
+    label1 = vtk.vtkTextActor()
+    label1.SetInput("Model 1")
+    label1.SetTextProperty(text_property1)
+    label1.SetDisplayPosition(50, 50)  # Set position
+
+    # Read the second STL model
+    reader2 = vtk.vtkSTLReader()
+    reader2.SetFileName(file_path2)
+
+    mapper2 = vtk.vtkPolyDataMapper()
+    mapper2.SetInputConnection(reader2.GetOutputPort())
+
+    actor2 = vtk.vtkActor()
+    actor2.SetMapper(mapper2)
+
+    # Create label
+    text_property2 = vtk.vtkTextProperty()
+    text_property2.SetFontFamilyToArial()
+    text_property2.SetFontSize(16)
+    text_property2.BoldOn()
+    text_property2.ItalicOff()
+    text_property2.ShadowOn()
+
+    label2 = vtk.vtkTextActor()
+    label2.SetInput("Model 2")
+    label2.SetTextProperty(text_property2)
+    label2.SetDisplayPosition(50, 100)  # Set position
+
+    renderer = vtk.vtkRenderer()
+    renderer.AddActor(actor1)
+    renderer.AddActor(actor2)
+    renderer.AddActor2D(label1)  # Add label to renderer
+    renderer.AddActor2D(label2)  # Add label to renderer
+
+    render_window = vtk.vtkRenderWindow()
+    render_window.AddRenderer(renderer)
+    render_window.SetSize(800, 600)  # Set window size to 800x600
+
+    interactor = vtk.vtkRenderWindowInteractor()
+    interactor.SetRenderWindow(render_window)
+
+    # Set custom interaction style
+    style = MyInteractorStyle()
+    style.actor1 = actor1
+    style.actor2 = actor2
+    style.label1 = label1
+    style.label2 = label2
+    interactor.SetInteractorStyle(style)
+
+    # Set default renderer
+    style.SetDefaultRenderer(renderer)
+
+    # Create button
+    callback = ButtonCallback(renderer, actor1)
+    button1 = create_button(renderer, callback.execute, (100, 550), (100, 30), "Move Actor 1")
+
+    interactor.Initialize()
+    interactor.Start()
+
+if __name__ == "__main__":
+    # Replace with your two STL file paths
+    stl_file_path1 = "tetrahedron.stl"
+    stl_file_path2 = "tetrahedron.stl"
+    read_and_display_stl(stl_file_path1, stl_file_path2)
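read_and_display_stl expects a tetrahedron.stl next to the script. A minimal sketch for generating such a file with VTK, as an assumption for local testing (not part of the commit):

import vtk

# Write a unit tetrahedron to tetrahedron.stl so test_3D.py has an input mesh
points = vtk.vtkPoints()
for p in [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)]:
    points.InsertNextPoint(p)

cells = vtk.vtkCellArray()
for a, b, c in [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]:
    tri = vtk.vtkTriangle()
    tri.GetPointIds().SetId(0, a)
    tri.GetPointIds().SetId(1, b)
    tri.GetPointIds().SetId(2, c)
    cells.InsertNextCell(tri)

poly = vtk.vtkPolyData()
poly.SetPoints(points)
poly.SetPolys(cells)

writer = vtk.vtkSTLWriter()
writer.SetFileName("tetrahedron.stl")
writer.SetInputData(poly)
writer.Write()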

+ 87 - 0
test_RGSGQF_filter.py

@@ -0,0 +1,87 @@
+import numpy as np
+
+# Weight function definition
+def weight_function(zeta, k1=1.5, k2=5):
+    abs_zeta = np.abs(zeta)
+    psi = np.zeros_like(abs_zeta)
+    psi[abs_zeta <= k1] = 1
+    mask = (abs_zeta > k1) & (abs_zeta <= k2)
+    psi[mask] = (1 - ((abs_zeta[mask] - k1) ** 2 / (k2 - k1) ** 2)) ** 2
+    # Make sure the weights are never exactly zero
+    psi[psi < 1e-6] = 1e-6
+    # print(f"psi={psi}")
+    return psi
+
+# RESGQF filter function
+def resgqf_filter(x_hat, P, z, f, h, Q, R, k1=1.5, k2=5):
+    # State prediction
+    x_pred = f(x_hat)
+    P_pred = P + Q
+
+    # Measurement residual
+    z_hat = h(x_pred)
+    r = z - z_hat
+    # Normalized residual
+    R_inv_sqrt = np.linalg.inv(np.linalg.cholesky(R))
+    zeta = R_inv_sqrt @ r
+
+    # Weight matrix
+    psi = weight_function(zeta)
+    Psi = np.diag(psi)
+
+    # Make sure Psi is a 2D matrix
+    if Psi.ndim == 1:
+        Psi = np.diag(Psi)
+
+    # # Add a regularization term to avoid a singular matrix
+    # epsilon = 1e-6  # small regularization term
+    # Psi += epsilon * np.eye(Psi.shape[0])
+
+    # Update the measurement noise covariance
+    R_bar = R_inv_sqrt @ np.linalg.inv(Psi) @ R_inv_sqrt.T
+
+    # Cross-covariance P_xz and innovation covariance P_zz
+    P_xz = P_pred[0, 0] * z_hat[0]
+    P_zz = z_hat[0]**2 + R_bar[0, 0]
+
+
+    # Make sure P_zz is not zero to avoid division by zero
+    if P_zz < 1e-6:
+        P_zz = 1e-6
+
+    # Kalman gain
+    K = P_xz / P_zz
+
+    # State update
+    x_hat = x_pred + K * (z - z_hat[0])
+    P = P_pred - K**2 * P_xz
+
+    # Debug output
+    print(f"measurement: {z[0]}, predicted state: {x_pred[0]}, residual: {r[0]}, normalized residual: {zeta[0]}, weight: {psi[0]}, measurement noise covariance: {R_bar[0, 0]}, Kalman gain: {K}, estimated state: {x_hat[0]}, covariance: {P[0, 0]}")
+
+    return x_hat, P
+
+# Example parameter setup
+# State transition function
+def f(x):
+    return x  # simple identity transition
+
+# Measurement function
+def h(x):
+    return x  # simple identity mapping
+
+# Initial state estimate
+x_hat = np.array([0.1])
+# Initial covariance
+P = np.array([[1.0]])
+# Process noise covariance
+Q = np.array([[0.1]])
+# Measurement noise covariance
+R = np.array([[1.0]])
+
+# Simulated measurements, including outliers
+measurements = [1.1, 1.2, 10.0, 1.3, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 10, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4]  # 10.0 is an outlier
+
+# Run the filter
+for z in measurements:
+    x_hat, P = resgqf_filter(x_hat, P, np.array([z]), f, h, Q, R)
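For comparison, a plain (non-robust) scalar Kalman filter on the first few measurements makes the effect of the weight function visible: without the down-weighting, the 10.0 outlier pulls the estimate much harder. This is a hedged sketch for illustration, not part of the commit:

# Plain scalar Kalman filter (identity state and measurement model)
x_kf, P_kf = 0.1, 1.0     # state estimate and covariance
Q_kf, R_kf = 0.1, 1.0     # process and measurement noise variances
for z_kf in [1.1, 1.2, 10.0, 1.3, 1.4]:
    P_prior = P_kf + Q_kf              # predict
    K_kf = P_prior / (P_prior + R_kf)  # Kalman gain
    x_kf = x_kf + K_kf * (z_kf - x_kf) # update with the raw residual
    P_kf = (1 - K_kf) * P_prior
    print(f"z={z_kf}, plain KF estimate={x_kf:.3f}")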

+ 218 - 0
test_STEPS.py

@@ -0,0 +1,218 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.optimize import least_squares
+
+
+class STEPS:
+    """Sub-signal technique parameter estimation (STEPS) algorithm"""
+
+    def __init__(self, fs, N0):
+        """
+        Initialize the STEPS estimator
+        :param fs: sampling frequency
+        :param N0: length of the original signal
+        """
+        self.fs = fs
+        self.N0 = N0
+        # Sub-signal length is 2/3 of the original signal (the ratio recommended in the paper)
+        self.N1 = int(2 * N0 / 3)
+        self.N2 = self.N1
+        # Start indices of the sub-signals
+        self.n1 = 0
+        self.n2 = N0 - self.N2
+
+        # Make sure the sub-signal lengths are valid
+        if self.N1 <= 0 or self.N2 <= 0:
+            raise ValueError("Sub-signal length must be positive; increase the original signal length")
+
+        # Precompute the frequency resolution
+        self.freq_res = fs / N0
+
+    def extract_subsignals(self, x):
+        """
+        Extract the two sub-signals from the original signal
+        :param x: original signal
+        :return: the two sub-signals x1 and x2
+        """
+        if len(x) != self.N0:
+            raise ValueError(f"Input signal length must be {self.N0}, got {len(x)}")
+
+        x1 = x[self.n1: self.n1 + self.N1]
+        x2 = x[self.n2: self.n2 + self.N2]
+        return x1, x2
+
+    def compute_dft(self, x):
+        """Compute the DFT of the signal and return magnitude and phase"""
+        X = np.fft.fft(x)
+        amp = np.abs(X)
+        phase = np.angle(X)
+        return amp, phase
+
+    def find_peak_bin(self, amp):
+        """Find the peak location in the DFT magnitude spectrum"""
+        return np.argmax(amp)
+
+    def phase_correction(self, phase, k, N):
+        """
+        Phase correction to remove the linear phase term
+        :param phase: raw phase
+        :param k: frequency bin of the peak
+        :param N: signal length
+        :return: corrected phase
+        """
+        return phase - 2 * np.pi * k * (N - 1) / (2 * N)
+
+    def objective_function(self, delta1, k1, k2, c1, c2, phi1, phi2):
+        """Objective function used to solve the nonlinear equation"""
+        # Left-hand side of the equation
+        left_numerator = c1 * np.tan(phi2) - c2 * np.tan(phi1)
+        left_denominator = c1 * c2 + np.tan(phi1) * np.tan(phi2)
+        left = left_numerator / left_denominator
+
+        # Right-hand side of the equation
+        term = np.pi * (k1 + delta1) / self.N1 * (2 * (self.n2 - self.n1) + self.N2 - self.N1)
+        right = np.tan(term)
+
+        # Return the error
+        return left - right
+
+    def estimate_frequency(self, x1, x2):
+        """
+        Estimate the signal frequency
+        :param x1, x2: the two sub-signals
+        :return: the estimated frequency
+        """
+        # Compute the DFT of the sub-signals
+        amp1, phase1 = self.compute_dft(x1)
+        amp2, phase2 = self.compute_dft(x2)
+
+        # Find the peak locations
+        k1 = self.find_peak_bin(amp1)
+        k2 = self.find_peak_bin(amp2)
+
+        # Phase correction
+        phi1 = self.phase_correction(phase1[k1], k1, self.N1)
+        phi2 = self.phase_correction(phase2[k2], k2, self.N2)
+
+        # Compute the c1 and c2 parameters
+        c1 = np.sin(np.pi * k1 / self.N1) / np.sin(np.pi * (k1 + 1) / self.N1)
+        c2 = np.sin(np.pi * k2 / self.N2) / np.sin(np.pi * (k2 + 1) / self.N2)
+
+        # Solve the nonlinear equation for delta1
+        def func(delta):
+            return self.objective_function(delta, k1, k2, c1, c2, phi1, phi2)
+
+        # Solve with least squares
+        result = least_squares(func, x0=0.5, bounds=(0, 1))
+        delta1 = result.x[0]
+
+        # Compute the frequency
+        l0 = (self.N0 / self.N1) * (k1 + delta1)
+        f0 = (l0 / self.N0) * self.fs
+
+        return f0, k1, amp1[k1]
+
+    def estimate_amplitude(self, k1, amp1_peak):
+        """Estimate the signal amplitude"""
+        # Compute the amplitude correction factor
+        delta1 = (self.N1 / self.N0) * (self.fs / self.freq_res) - k1
+        correction = np.abs(np.sin(np.pi * delta1) / (self.N1 * np.sin(np.pi * (k1 + delta1) / self.N1)))
+        amplitude = amp1_peak * correction
+        return amplitude
+
+    def estimate_phase(self, f0, k1, amp1_peak):
+        """Estimate the initial phase of the signal"""
+        delta1 = (self.N1 * f0 / self.fs) - k1
+        phase = np.angle(amp1_peak) - np.pi * delta1 * (self.N1 - 1) / self.N1
+        # Wrap the phase to the [-π, π] range
+        return (phase + np.pi) % (2 * np.pi) - np.pi
+
+    def estimate(self, x):
+        """
+        Estimate the parameters of the sinusoidal signal
+        :param x: input signal
+        :return: estimated frequency, amplitude and phase
+        """
+        # Extract the sub-signals
+        x1, x2 = self.extract_subsignals(x)
+
+        # Estimate the frequency
+        f0, k1, amp1_peak = self.estimate_frequency(x1, x2)
+
+        # Estimate the amplitude
+        amp = self.estimate_amplitude(k1, amp1_peak)
+
+        # Estimate the phase
+        phase = self.estimate_phase(f0, k1, amp1_peak)
+
+        return f0, amp, phase
+
+
+# Demonstrate how to use the STEPS class
+if __name__ == "__main__":
+    # Generate a test signal
+    fs = 1000  # sampling frequency
+    N0 = 1024  # signal length
+    f_true = 75.3  # true frequency
+    amp_true = 2.5  # true amplitude
+    phase_true = np.pi / 3  # true phase
+    snr_db = 40  # signal-to-noise ratio (dB)
+
+    # Time vector
+    t = np.arange(N0) / fs
+
+    # Clean signal
+    x_clean = amp_true * np.sin(2 * np.pi * f_true * t + phase_true)
+
+    # Add noise
+    snr = 10 ** (snr_db / 10)
+    signal_power = np.sum(x_clean ** 2) / N0
+    noise_power = signal_power / snr
+    noise = np.sqrt(noise_power) * np.random.randn(N0)
+    x = x_clean + noise
+
+    # Create the STEPS estimator and run the parameter estimation
+    steps = STEPS(fs, N0)
+    f_est, amp_est, phase_est = steps.estimate(x)
+
+    # Compute the errors
+    f_error = np.abs(f_est - f_true)
+    amp_error = np.abs(amp_est - amp_true)
+    phase_error = np.abs(phase_est - phase_true)
+
+    # Show the results
+    print(f"True frequency: {f_true:.4f} Hz, estimated frequency: {f_est:.4f} Hz, error: {f_error:.6f} Hz")
+    print(f"True amplitude: {amp_true:.4f}, estimated amplitude: {amp_est:.4f}, error: {amp_error:.6f}")
+    print(f"True phase: {phase_true:.4f} rad, estimated phase: {phase_est:.4f} rad, error: {phase_error:.6f} rad")
+
+    # Plot the signal and its spectrum
+    plt.figure(figsize=(12, 8))
+
+    # Font configuration (SimHei, originally added for Chinese labels)
+    plt.rcParams["font.family"] = ["SimHei"]
+    plt.rcParams["axes.unicode_minus"] = False  # fix minus-sign rendering
+
+    # Time-domain signal
+    plt.subplot(2, 1, 1)
+    plt.plot(t, x_clean, label='Clean signal')
+    plt.plot(t, x, alpha=0.7, label='Noisy signal')
+    plt.xlabel('Time (s)')
+    plt.ylabel('Amplitude')
+    plt.title('Time-domain signal')
+    plt.legend()
+
+    # Spectrum
+    plt.subplot(2, 1, 2)
+    freq = np.fft.fftfreq(N0, 1 / fs)
+    x_fft = np.fft.fft(x)
+    plt.plot(freq[:N0 // 2], 20 * np.log10(np.abs(x_fft[:N0 // 2])), label='Signal spectrum')
+    plt.xlabel('Frequency (Hz)')
+    plt.ylabel('Magnitude (dB)')
+    plt.title('Signal spectrum')
+    plt.xlim(0, 150)  # limit the frequency range for easier viewing
+    plt.axvline(f_true, color='r', linestyle='--', label=f'True frequency: {f_true} Hz')
+    plt.axvline(f_est, color='g', linestyle='--', label=f'Estimated frequency: {f_est:.2f} Hz')
+    plt.legend()
+
+    plt.tight_layout()
+    plt.show()
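As a baseline, the plain single-DFT peak estimate is limited to the bin spacing fs / N0 = 1000 / 1024 ≈ 0.977 Hz; the sub-signal phase step in STEPS is what refines the estimate below that resolution. A small sketch of that baseline (an assumption, reusing x, fs, N0 and f_true from the demo above):

    # Naive periodogram-peak estimate for comparison with STEPS
    spec = np.abs(np.fft.rfft(x))
    k_peak = np.argmax(spec)
    f_naive = k_peak * fs / N0
    print(f"Naive DFT-peak estimate: {f_naive:.4f} Hz (true {f_true} Hz)")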

+ 121 - 0
test_ai.py

@@ -0,0 +1,121 @@
+import math
+import tkinter as tk
+from tkinter import messagebox, ttk
+from decimal import Decimal, getcontext
+import threading
+
+# Function that computes pi with the Chudnovsky algorithm
+def compute_pi_chudnovsky(iterations, progress_callback):
+    # Set the Decimal precision to guarantee enough significant digits
+    getcontext().prec = 250  # use a higher precision so intermediate results stay accurate
+
+    # Initialize constants
+    C = Decimal(426880) * Decimal(10005).sqrt()
+    M = Decimal(1)
+    L = Decimal(13591409)
+    X = Decimal(-262537412640768000)  # initial X
+    K = Decimal(6)
+    S = Decimal(L)
+
+    total_iterations = iterations
+    for i in range(1, total_iterations + 1):  # iterate starting from 1
+        if i % 10 == 0:
+            progress_callback((i / total_iterations) * 100)
+
+        # Update M and L
+        M = (K**3 - 16 * K) * M // Decimal(i**3)
+        L += Decimal(545140134)
+
+        # Update S
+        try:
+            S += M * L / X
+        except ZeroDivisionError as e:
+            raise ValueError(f"Division by zero at iteration: {i}, X: {X}, M: {M}, L: {L}") from e
+        except Exception as e:
+            raise ValueError(f"Other error at iteration: {i}, X: {X}, M: {M}, L: {L}, S: {S}") from e
+
+        # Update X and K
+        X *= Decimal(-262537412640768000)
+        K += Decimal(12)
+
+    pi = C / S
+    return +pi  # unary plus rounds the result to the current context precision
+
+# Worker thread function that computes pi in the background
+def worker(iterations, progress_callback, result_callback):
+    try:
+        pi_computed = compute_pi_chudnovsky(iterations, progress_callback)
+        result_callback(pi_computed)
+    except Exception as e:
+        result_callback(None, str(e))
+
+# Click handler for the compute button
+def on_calculate():
+    try:
+        iterations = int(entry_iterations.get())
+        if iterations < 1:
+            raise ValueError("The number of iterations must be greater than 0")
+
+        # Create the progress window
+        progress_window = tk.Toplevel(root)
+        progress_window.title("Progress")
+
+        # Initialize the progress bar variable
+        progress_var = tk.DoubleVar()
+        # Create the progress bar widget
+        progress = ttk.Progressbar(progress_window, variable=progress_var, maximum=100)
+        progress.pack(pady=20, padx=20, fill='x')
+        # Create the progress label widget
+        progress_label = tk.Label(progress_window, text="0%")
+        progress_label.pack(pady=10)
+
+        def progress_callback(progress):
+            progress_var.set(progress)
+            progress_label.config(text=f"{int(progress)}%")
+            progress_window.update_idletasks()
+
+        # Result callback
+        def result_callback(pi_computed, error=None):
+            progress_window.destroy()
+            if error:
+                messagebox.showerror("Error", error)
+            else:
+                # Get pi from the math library (wrapped in a Decimal here)
+                pi_math = Decimal(str(math.pi))
+
+                # Compute the difference
+                difference = abs(pi_computed - pi_math)
+
+                # Format the computed result and the math-library pi, and show the difference in scientific notation
+                pi_computed_str = f"{pi_computed:.200f}".rstrip('0').rstrip('.')
+                pi_math_str = f"{pi_math:.200f}".rstrip('0').rstrip('.')
+                difference_str = f"{difference:.20e}"
+
+                message = (f"Computed pi: {pi_computed_str}\n"
+                           f"math library pi: {pi_math_str}\n"
+                           f"Difference: {difference_str}")
+                messagebox.showinfo("Result", message)
+
+        # Create and start the worker thread
+        thread = threading.Thread(target=worker, args=(iterations, progress_callback, result_callback))
+        thread.start()
+    except ValueError as e:
+        messagebox.showerror("Error", str(e))
+
+# Create the main window
+root = tk.Tk()
+root.title("Compute Pi")
+
+# Create the label and input field
+label_iterations = tk.Label(root, text="Iterations:")
+label_iterations.grid(row=0, column=0, padx=10, pady=10)
+
+entry_iterations = tk.Entry(root)
+entry_iterations.grid(row=0, column=1, padx=10, pady=10)
+
+# Create the compute button
+button_calculate = tk.Button(root, text="Compute", command=on_calculate)
+button_calculate.grid(row=1, column=0, columnspan=2, pady=10)
+
+# Run the main loop
+root.mainloop()
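The loop above evaluates the Chudnovsky series incrementally, updating M with the recurrence M_k = M_(k-1) * (K^3 - 16K) / k^3 where K = 12k - 6. For cross-checking, a direct factorial-based sketch of the same series (an assumption, not part of the commit; each term contributes roughly 14 correct digits):

from decimal import Decimal, getcontext
from math import factorial

def chudnovsky_direct(n_terms, prec=120):
    # Direct Chudnovsky sum: pi = 426880*sqrt(10005) / sum_k M_k * L_k / X_k
    getcontext().prec = prec
    C = 426880 * Decimal(10005).sqrt()
    S = Decimal(0)
    for k in range(n_terms):
        M = Decimal(factorial(6 * k)) / (Decimal(factorial(3 * k)) * Decimal(factorial(k)) ** 3)
        L = Decimal(13591409 + 545140134 * k)
        X = Decimal(-262537412640768000) ** k
        S += M * L / X
    return C / S

print(chudnovsky_direct(8))  # about 100 correct digits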

+ 24 - 0
test_gym.py

@@ -0,0 +1,24 @@
+import gym
+import time
+
+# Create the environment
+env = gym.make('CartPole-v1', render_mode='human')
+# Initialize the environment
+state = env.reset(seed=1)
+# Interaction loop
+
+while True:
+    # Render the frame
+    # env.render()
+    # Sample a random action from the action space
+    action = env.action_space.sample()
+    # The agent takes one step in the environment
+    state, reward, terminated, truncated, info = env.step(action)
+    print('state = {0}; reward = {1}'.format(state, reward))
+    # Check whether the current episode has finished
+    if terminated:
+        print('terminated')
+        break
+    time.sleep(0.1)
+# End of environment
+# env.close()
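One detail: with the gym/gymnasium versions that return the 5-tuple from env.step (as used above), env.reset returns an (observation, info) pair, so the first printed state is actually that tuple. A small sketch of unpacking it and also stopping on truncation (an assumption about the installed gym version):

state, info = env.reset(seed=1)
terminated = truncated = False
while not (terminated or truncated):
    state, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()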

+ 40 - 10
test_pytorch.py

@@ -5,20 +5,24 @@ import numpy as np
 import matplotlib.pyplot as plt
 
 
-print(torch.__version__)  # PyTorch version
-print(torch.version.cuda)  # CUDA version
+print(f"torch.__version__:{torch.__version__}")  # PyTorch version
+print(f"torch.version.cuda:{torch.version.cuda}")  # CUDA version
 print(torch.cuda.is_available())  # check whether CUDA is available
 
 #
 # Use GPU or CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
+print(f"using device:{device}")
 
 # Generate some random training data
 np.random.seed(42)
 x = np.random.rand(1000, 1)
-y = 2 * x + 1 + 0.5 * np.random.randn(1000, 1)
+y = 10 * x + 1 + 0.5 * np.random.randn(1000, 1)
+# y[900]=1000
+
+
+
 
 # Convert the data to tensors
 x_tensor = torch.from_numpy(x).float()
@@ -41,7 +45,9 @@ criterion = nn.MSELoss()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 
 # Train the model
-num_epochs = 1000
+num_epochs = 6000
+dbug_epos=np.zeros(num_epochs)
+dbug_loss=np.zeros(num_epochs)
 for epoch in range(num_epochs):
     # Forward pass
     outputs = model(x_tensor)
@@ -52,6 +58,10 @@ for epoch in range(num_epochs):
     loss.backward()
     optimizer.step()
 
+    dbug_epos[epoch]=epoch
+    dbug_loss[epoch] = loss.item()
+
+
     if (epoch + 1) % 10 == 0:
         print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')
 
@@ -60,9 +70,29 @@ with torch.no_grad():
     predicted = model(x_tensor)
 
 # Plot the original data and the prediction results
-plt.scatter(x, y, label='Original Data')
-plt.plot(x, predicted.numpy(), color='red', label='Predicted Line')
-plt.xlabel('x')
-plt.ylabel('y')
-plt.legend()
+
+# plt.scatter(x, y, label='Original Data')
+# plt.plot(x, predicted.numpy(), color='red', label='Predicted Line')
+# plt.xlabel('x')
+# plt.ylabel('y')
+# plt.legend()
+# plt.show()
+
+fig1, ax1 = plt.subplots()
+ax1.scatter(x, y, label='Original Data')
+ax1.plot(x, predicted.numpy(), color='red', label='Predicted Line')
+ax1.set_xlabel('x')
+ax1.set_ylabel('y')
+ax1.legend()
+
+
+fig2, ax2 = plt.subplots()
+ax2.scatter(dbug_epos, dbug_loss, label='loss')
+ax2.plot(dbug_epos, dbug_loss, color='red', label='loss')
+ax2.set_xlabel('dbug_epos')
+ax2.set_ylabel('dbug_loss')
+ax2.legend()
+
+
+
 plt.show()
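The diff selects and prints a device, but at least in the visible hunks nothing is moved onto it, so training still runs on whatever device the tensors were created on. A hedged sketch of actually using it, assuming model, x_tensor and y_tensor as defined in this file:

model = model.to(device)
x_tensor = x_tensor.to(device)
y_tensor = y_tensor.to(device)
# ... training loop unchanged ...
with torch.no_grad():
    predicted = model(x_tensor).cpu()  # back to the CPU for plotting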

+ 365 - 0
test_pytorch_MNIST.py

@@ -0,0 +1,365 @@
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+import numpy as np
+import tkinter as tk
+from PIL import Image, ImageDraw
+from tkinter import messagebox
+import cv2
+from torchsummary import summary
+from torchviz import make_dot
+import netron
+
+
+
+workmode = 1  # 0: train  1: load a saved model
+
+
+# Data preprocessing
+if workmode == 0:
+    transform = transforms.Compose([
+        # transforms.RandomRotation(10),  # random rotation by up to 10 degrees
+        # transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # random translation
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+else:
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+
+# Load the training and test sets
+train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
+test_dataset = datasets.MNIST('data', train=False, transform=transform)
+
+train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
+test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
+
+test_loader_predicted=np.zeros(len(test_loader.dataset))
+
+
+def showfig(images,labels,totalnum,num):
+    # Number of images to display
+    num_images = num
+
+    # Create a subplot layout
+    fig, axes = plt.subplots(1, num_images, figsize=(15, 3))
+
+    # Iterate over the dataset and show images with their labels
+    for i in range(num_images):
+        # image, label = train_dataset[i]
+        random_numbers = np.random.choice(totalnum, num_images, replace=False)  # valid indices are 0..totalnum-1
+        # image = train_dataset.train_data[random_numbers[i]]
+        # label = train_dataset.train_labels[random_numbers[i]]
+        image = images[random_numbers[i]]
+        label = labels[random_numbers[i]]
+        # Convert the tensor to a numpy array and adjust the dimensions
+        image = image.squeeze().numpy()
+        # Show the image
+        axes[i].imshow(image, cmap='gray')
+        # Use the label as the title
+        axes[i].set_title(f'idx-{random_numbers[i]}-Label: {label}')
+        axes[i].axis('off')
+    # # # Show the figure
+    # plt.show()
+
+
+# while 1:
+#     pass
+showfig(train_dataset.data,train_dataset.targets,60000,4)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# # Define the neural network model
+# class Net(nn.Module):
+#     def __init__(self):
+#         super(Net, self).__init__()
+#         self.fc1 = nn.Linear(784, 128)
+#         self.fc2 = nn.Linear(128, 64)
+#         self.fc3 = nn.Linear(64, 10)
+#
+#     def forward(self, x):
+#         x = x.view(-1, 784)
+#         x = torch.relu(self.fc1(x))
+#         x = torch.relu(self.fc2(x))
+#         x = self.fc3(x)
+#         return x
+# model = Net()
+# criterion = nn.CrossEntropyLoss()
+# optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+
+
+# # Define a modified AlexNet adapted to the MNIST dataset
+# class AlexNet(nn.Module):
+#     def __init__(self, num_classes=10):
+#         super(AlexNet, self).__init__()
+#         self.features = nn.Sequential(
+#             nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # adjusted kernel size and channel count
+#             nn.ReLU(inplace=True),
+#             nn.MaxPool2d(kernel_size=2, stride=2),  # adjusted pooling size and stride
+#             nn.Conv2d(32, 64, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.MaxPool2d(kernel_size=2, stride=2),
+#             nn.Conv2d(64, 128, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.Conv2d(128, 128, kernel_size=3, padding=1),
+#             nn.ReLU(inplace=True),
+#             nn.MaxPool2d(kernel_size=2, stride=2),
+#         )
+#         self.classifier = nn.Sequential(
+#             nn.Dropout(),
+#             nn.Linear(128 * 3 * 3, 128),  # adjusted fully-connected input dimension
+#             nn.ReLU(inplace=True),
+#             nn.Dropout(),
+#             nn.Linear(128, 128),
+#             nn.ReLU(inplace=True),
+#             nn.Linear(128, num_classes),
+#         )
+#
+#     def forward(self, x):
+#         x = self.features(x)
+#         x = x.view(x.size(0), 128 * 3 * 3)  # adjusted flattened dimension
+#         x = self.classifier(x)
+#         return x
+#
+
+#
+# # Initialize the model, loss function and optimizer
+# model = AlexNet().to(device)
+
+
+# LeNet-5 model definition
+class LeNet5(nn.Module):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
+        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
+        self.fc1 = nn.Linear(16 * 4 * 4, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = torch.relu(self.conv1(x))
+        x = nn.MaxPool2d(2)(x)
+        x = torch.relu(self.conv2(x))
+        x = nn.MaxPool2d(2)(x)
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.fc1(x))
+        x = torch.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+# Instantiate the model
+model = LeNet5()
+
+
+
+
+
+
+
+criterion = nn.CrossEntropyLoss()
+optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+
+
+# # Print the model structure
+# summary(model, input_size=(1, 28, 28))
+
+# # Create a random input tensor
+# x = torch.randn(1, 1, 28, 28).to(device)
+# # Forward pass
+# y = model(x).to(device)
+# # Generate the computation graph with torchviz
+# dot = make_dot(y, params=dict(model.named_parameters()))
+# # Save the computation graph as an image file (PNG here)
+# dot.render('alexnet_model', format='png', cleanup=True, view=True)
+
+# Training loop
+def train(model, train_loader, optimizer, criterion, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % 100 == 0:
+            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+
+# Evaluation loop
+
+
+def test(model, test_loader):
+    model.eval()
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for batch_idx ,(data, target) in enumerate(test_loader):
+            output = model(data)
+            _, predicted = torch.max(output.data, 1)
+            test_loader_predicted[batch_idx*64+0:batch_idx*64+64]=predicted.numpy()
+            total += target.size(0)
+            correct += (predicted == target).sum().item()
+
+    accuracy = 100 * correct / total
+    print('Test Accuracy: {:.2f}%'.format(accuracy))
+
+# # Train and evaluate the model
+
+if workmode==0:  # 0: train
+    for epoch in range(10):
+        train(model, train_loader, optimizer, criterion, epoch)
+        test(model, test_loader)
+
+    torch.save(model, 'model.pth')
+    print(f'save model as model.pth')
+else:            # 1: load the saved model
+    print(f'load model.pth')
+    try:
+        model = torch.load('model.pth', weights_only=False)
+        model.eval()
+    except Exception as e:
+        print(f"加载模型时出现错误: {e}")
+    test(model, test_loader)
+
+    # netron.start('model.pth')  # visualize the network structure
+
+showfig(test_loader.dataset.data, test_loader_predicted, 10000, 4)
+plt.show()
+
+
+def save_drawing():
+    global drawing_points
+    # Create a blank image
+    image = Image.new("RGB", (canvas.winfo_width(), canvas.winfo_height()), "black")
+    draw = ImageDraw.Draw(image)
+
+    # Draw the strokes
+    for i in range(1, len(drawing_points)):
+        x1,y1=drawing_points[i - 1]
+        x2, y2 = drawing_points[i]
+        if (x1 is not None) and (x2 is not None) and (y1 is not None) and (y2 is not None):
+            draw.line((x1, y1, x2, y2), fill="white", width=20)
+
+    image = image.convert('L')
+    image1=image.resize((28,28))
+    # # 4. Convert to a numpy array
+    # image_array = np.array(image1)
+    #
+    # # 5. Binarize
+    # _, binary_image = cv2.threshold(image_array, 127, 255, cv2.THRESH_BINARY)
+    #
+    # # 6. Center the digit
+    # rows, cols = binary_image.shape
+    # M = cv2.moments(binary_image)
+    # if M["m00"] != 0:
+    #     cX = int(M["m10"] / M["m00"])
+    #     cY = int(M["m01"] / M["m00"])
+    # else:
+    #     cX, cY = 0, 0
+    # shift_x = cols / 2 - cX
+    # shift_y = rows / 2 - cY
+    # M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
+    # centered_image = cv2.warpAffine(binary_image, M, (cols, rows))
+    #
+    # # 7. Normalize
+    # normalized_image = centered_image / 255.0
+    #
+    # # 8. Reshape to fit the model input
+    # final_image = normalized_image.reshape(28, 28)
+    # image1 = Image.fromarray(final_image)
+
+
+
+    # # Convert to a numpy array
+    # img_array = np.array(image1)
+    # # Median filtering
+    # filtered_img = cv2.medianBlur(img_array, 3)
+    # # Convert back to an Image object (if needed)
+    # image1 = Image.fromarray(filtered_img)
+
+    tensor_image = transform(image1)  # torch.Size([1, 28, 28]) after convert('L')
+
+    # gray_tensor = torch.mean(tensor_image, dim=0, keepdim=True)
+    # pool = torch.nn.MaxPool2d(kernel_size=10, stride=10)
+    # pooled_image = pool(gray_tensor.unsqueeze(0)).squeeze(0)
+
+    # pooled_image = gray_tensor
+    pooled_image = tensor_image.unsqueeze(0)
+
+
+    # print(f'tensor_image :{gray_tensor.shape} -pooled_image:{pooled_image.shape}')
+
+    # simage=pooled_image.view(28,28)
+    # simage = (simage - simage.min()) / (simage.max() - simage.min())
+    # np_array = (simage.numpy() * 255).astype('uint8')
+    # image_f = Image.fromarray(np_array)
+    # image_f.show()
+
+
+
+    with torch.no_grad():
+        output = torch.softmax(model(pooled_image),1)
+        # print(f'output.data={output}')
+        v, predicted = torch.max(output, 1)
+
+    print(f'Predicted digit={predicted.numpy()[0]}, probability: {(v*100).numpy()[0]:.2f}%')
+    messagebox.showinfo('Recognition result', f'predicted={predicted.numpy()[0]}')
+    drawing_points=[]
+    canvas.delete("all")
+
+    # Save the images
+    image.save("drawing.png")
+    image1.save("drawing28x28.png")
+    # print("Drawing saved as drawing.png")
+
+last_x,last_y=None,None
+# last_y=[]
+def on_mouse_move(event):
+    global last_x,last_y
+
+    # drawing_points.append((event.x, event.y))
+    drawing_points.append((last_x, last_y))
+    if (last_x is not None) and  (last_y is not None) :
+        canvas.create_line(last_x, last_y, event.x, event.y, fill="white", width=20, smooth=True, splinesteps=10)
+        # canvas.create_line(last_x, last_y , event.x, event.y, fill="white", width=20)
+
+    last_x, last_y = event.x, event.y
+
+
+def on_mouse_release(event):
+    global last_x, last_y
+    last_x, last_y = None, None
+    # print("on_mouse_release")
+    pass
+
+
+
+
+
+root = tk.Tk()
+
+canvas = tk.Canvas(root, width=280*2, height=280*2, bg="black")
+canvas.pack()
+
+# canvas_show = tk.Canvas(root, width=280, height=280, bg="black")
+# canvas_show.pack()
+
+button = tk.Button(root, text="Recognize", command=save_drawing)
+button.pack()
+
+drawing_points = []
+canvas.bind("<B1-Motion>", on_mouse_move)
+canvas.bind("<ButtonRelease-1>", on_mouse_release)
+
+root.mainloop()
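torch.save(model, 'model.pth') pickles the whole LeNet5 object, which is why loading later needs weights_only=False and the class definition in scope. A common alternative, shown here only as a sketch (not what this commit does), is to store just the parameters:

torch.save(model.state_dict(), 'model_state.pth')   # save parameters only

model = LeNet5()
model.load_state_dict(torch.load('model_state.pth'))
model.eval()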

+ 2 - 1
test_pytorch_cuda.py

@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
 
 # Check whether a GPU is available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"using device:{device}")
 
 # Define the data preprocessing steps
 transform = transforms.Compose([
@@ -49,7 +50,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.01)
 # Train the model
 def train_model():
     model.train()
-    for epoch in range(20):
+    for epoch in range(5):  # should be 20
         running_loss = 0.0
         for batch_idx, (data, target) in enumerate(train_loader):
             data, target = data.to(device), target.to(device)

+ 112 - 0
test_pytorch_lstm.py

@@ -0,0 +1,112 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import matplotlib.pyplot as plt
+
+# Set random seeds so the results are reproducible
+np.random.seed(0)
+torch.manual_seed(0)
+
+# 1. Generate a random one-dimensional sequence
+sequence_length = 100
+# random_sequence = np.random.randint(0, 200, size=sequence_length)
+random_sequence = np.linspace(0, 200, sequence_length)
+# Save the sequence to a file
+np.savetxt('random_sequence.csv', random_sequence, delimiter=',')
+
+# 2. Define the LSTM model
+class LSTMModel(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, output_size):
+        super(LSTMModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
+        self.fc = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
+        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
+        out, _ = self.lstm(x, (h0, c0))
+        out = self.fc(out[:, -1, :])
+        return out
+
+# 3. Prepare the data
+def create_sequences(data, seq_length):
+    xs, ys = [], []
+    for i in range(len(data) - seq_length):
+        x = data[i:i+seq_length]
+        y = data[i+seq_length]
+        xs.append(x)
+        ys.append(y)
+    return np.array(xs), np.array(ys)
+
+# Parameter settings
+input_size = 1
+hidden_size = 32
+num_layers = 2
+output_size = 1
+seq_length = 10
+
+# Build the sequences
+X, y = create_sequences(random_sequence, seq_length)
+
+# Convert to PyTorch tensors
+X = torch.tensor(X, dtype=torch.float32).unsqueeze(2)
+y = torch.tensor(y, dtype=torch.float32).unsqueeze(1)
+
+# 4. Train the LSTM model
+model = LSTMModel(input_size, hidden_size, num_layers, output_size)
+criterion = nn.MSELoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+num_epochs = 100*100
+for epoch in range(num_epochs):
+    model.train()
+    outputs = model(X)
+    loss = criterion(outputs, y)
+
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+    if (epoch+1) % 10 == 0:
+        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
+
+# 5. Generate predictions
+model.eval()
+with torch.no_grad():
+    # Initialize test_input with the first seq_length samples
+    test_input = X[0].unsqueeze(0)  # shape (1, seq_length, 1)
+    predicted = []
+    for _ in range(sequence_length - seq_length):
+        output = model(test_input)
+        predicted.append(output.item())
+        # Update test_input: drop the first time step and append the predicted output
+        test_input = torch.cat((test_input[:, 1:], output.view(1, 1, 1)), dim=1)
+
+    predicted = np.array(predicted)
+
+np.savetxt('random_sequence_predicted.csv', predicted, delimiter=',')
+
+# 6. Plot the comparison with matplotlib
+plt.figure(figsize=(12, 6))
+
+# Plot the original sequence
+plt.subplot(1, 2, 1)
+plt.scatter(range(sequence_length), random_sequence, label='Original Sequence', color='blue', s=50)
+plt.title('Original sequence')
+plt.xlabel('Time step')
+plt.ylabel('Value')
+plt.legend()
+
+# Plot the predicted sequence
+plt.subplot(1, 2, 2)
+plt.scatter(range(seq_length, sequence_length), predicted, label='Predicted Sequence', color='red', s=50)
+plt.title('Predicted sequence')
+plt.xlabel('Time step')
+plt.ylabel('Value')
+plt.legend()
+
+plt.tight_layout()
+plt.show()
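The ramp runs from 0 to 200 and is fed to the LSTM unscaled, which is part of why the script needs 100*100 epochs. A hedged sketch of min-max scaling the sequence before windowing and undoing the scaling on the predictions (an assumption, not part of the commit):

# Scale the sequence to [0, 1] before building the windows
seq_min, seq_max = random_sequence.min(), random_sequence.max()
scaled = (random_sequence - seq_min) / (seq_max - seq_min)
X, y = create_sequences(scaled, seq_length)
# ... train exactly as above on the scaled tensors ...
# predicted_real = predicted * (seq_max - seq_min) + seq_min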

+ 43 - 0
test_supervision.py

@@ -0,0 +1,43 @@
+import supervision as sv
+import cv2
+import os
+from inference import get_model
+from ultralytics import YOLO
+
+
+print(sv.__version__)
+
+HOME = os.getcwd()
+# print(HOME)
+IMAGE_PATH = f"{HOME}//images//dog-3.jpeg"
+image = cv2.imread(IMAGE_PATH)
+
+# # Fixed the cv2.imshow() call by adding a window name
+# cv2.imshow('Display Image', image)
+# # Wait for a key event so the window does not close immediately
+# cv2.waitKey(0)
+# # Close all OpenCV windows
+# cv2.destroyAllWindows()
+
+# model = get_model(model_id="yolov8s-640")
+# result = model.infer(image)[0]
+# detections = sv.Detections.from_inference(result)
+
+
+
+model = YOLO("yolov8s.pt")
+result = model(image, verbose=False)[0]
+detections = sv.Detections.from_ultralytics(result)
+
+# print(detections)
+box_annotator = sv.BoxAnnotator()
+label_annotator = sv.LabelAnnotator()
+
+annotated_image = image.copy()
+annotated_image = box_annotator.annotate(annotated_image, detections=detections)
+annotated_image = label_annotator.annotate(annotated_image, detections=detections)
+
+sv.plot_image(image=annotated_image, size=(8, 8))
+
+print('end of file')
+

+ 20 - 0
test_supervision_videos.py

@@ -0,0 +1,20 @@
+from supervision.assets import download_assets, VideoAssets
+import supervision as sv
+import cv2
+import os
+from inference import get_model
+from ultralytics import YOLO
+
+print(sv.__version__)
+HOME = os.getcwd()
+
+download_assets(VideoAssets.VEHICLES)
+VIDEO_PATH = VideoAssets.VEHICLES.value
+
+sv.VideoInfo.from_video_path(video_path=VIDEO_PATH)
+frame_generator = sv.get_video_frames_generator(source_path=VIDEO_PATH)
+frame = next(iter(frame_generator))
+sv.plot_image(image=frame, size=(4, 4))
+
+RESULT_VIDEO_PATH = f"{HOME}/videos/vehicle-counting-result.mp4"
+

+ 95 - 0
test_web_auto.py

@@ -0,0 +1,95 @@
+#
+from selenium import webdriver
+from selenium.webdriver.chrome.service import Service
+from selenium.webdriver.chrome.options import Options
+from selenium.webdriver.common.by import By
+from selenium.webdriver.support.ui import WebDriverWait
+from selenium.webdriver.support import expected_conditions as EC
+import time
+
+
+def traverse_iframes(driver, iframe_list=None):
+    """
+    Recursively traverse all iframes on the page
+
+    :param driver: the Selenium WebDriver instance
+    :param iframe_list: list used to collect the iframes visited so far, defaults to None
+    :return: a list containing all iframe elements
+    """
+    if iframe_list is None:
+        iframe_list = []
+    # Find all iframe elements on the current page
+    iframes = driver.find_elements("tag name", "iframe")
+    for iframe in iframes:
+        iframe_list.append(iframe)
+        # Switch into the current iframe
+        driver.switch_to.frame(iframe)
+        print(iframe.id)
+        e = driver.find_elements(By.ID, 'switcher_plogin')
+        if e:  # find_elements returns a list, never None
+            # e[0].click()
+            print(f'e={e}')
+        # Recurse to find iframes nested inside the current iframe
+        traverse_iframes(driver, iframe_list)
+        # Switch back to the parent frame so sibling iframes can be visited
+        driver.switch_to.parent_frame()
+    return iframe_list
+
+# # Configure the ChromeDriver path; replace it with your own, or drop chromedriver into the project root and use './chromedriver.exe'.
+# chrome_driver_path = 'D:/JIAL/JIALConfig/chromedriver/chromedriver.exe'  # replace with your ChromeDriver path
+#
+# # Initialize the ChromeDriver Service
+# service = Service(chrome_driver_path)
+# # Browser launch options; enable or disable them as needed
+# options = Options()
+# options.add_argument("--start-maximized")  # maximize the window on startup
+# # options.add_argument("--disable-blink-features=AutomationControlled")  # hide the automation-controlled banner
+# # options.add_argument("--disable-gpu")  # disable GPU hardware acceleration
+# # options.add_argument("--disable-infobars")  # hide the info bar
+# # options.add_argument("--disable-extensions")  # disable all extensions
+# # options.add_argument("--disable-popup-blocking")  # disable popup blocking
+# # options.add_argument("--incognito")  # start in incognito mode
+# # options.add_argument("--no-sandbox")  # disable the sandbox (for performance)
+# # options.add_argument("--disable-dev-shm-usage")  # work around /dev/shm shared-memory issues
+# # options.add_argument("--remote-debugging-port=9222")  # enable the remote debugging port
+# # Initialize the WebDriver with the ChromeDriver Service
+driver = webdriver.Chrome()
+
+# Open the QQ Mail login page
+driver.get("https://mail.qq.com/")
+#
+# Print the page title
+print(driver.title)
+# Pause briefly so the browser stays open instead of closing immediately
+time.sleep(1)
+
+# def find_all_iframes(idriver):
+#     iframes = idriver.find_elements('tag name','iframe')
+#     for index, iframe in enumerate(iframes):
+#         # Your sweet business logic applied to iframe goes here.
+#         # print(iframe.id)
+#         driver.switch_to.frame(index)
+#         e = driver.find_elements(By.ID, 'switcher_plogin')
+#         print(e)
+#         find_all_iframes(idriver)
+#         driver.switch_to.parent_frame()
+
+# find_all_iframes(driver)
+all_iframes  = traverse_iframes(driver)
+# print(all_iframes)
+# print(driver.page_source)
+# driver.find_element_by_id('j_username').send_keys('***')
+# e=driver.find_element('class name','switcher_plogin')
+# driver.switch_to.frame('ptlogin_iframe')
+
+# e=driver.find_elements('tag name','iframe')
+# # e=driver.find_element(By.ID,'switcher_plogin')
+# print(e)
+#
+# driver.switch_to.frame(1)
+# e=driver.find_elements('tag name','iframe')
+# # e=driver.find_element(By.ID,'switcher_plogin')
+# print(e)
+time.sleep(1)
+# Close the WebDriver
+driver.quit()