Skip to content

Commit

Permalink
🧪 test(增加测试): 测试onnxsim相关代码
Browse files Browse the repository at this point in the history
  • Loading branch information
bruce1408 committed Apr 11, 2024
1 parent f8804c1 commit dcd0c78
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
25 changes: 25 additions & 0 deletions Tools/change_onnx_output_type.py
@@ -0,0 +1,25 @@
import onnx
from onnx import helper
from printk import print_colored_box

# 加载 ONNX 模型
model_path = '/mnt/share_disk/bruce_cui/onnx_models/laneline_154w_20240320_fastbev_onnxsim.onnx'
model = onnx.load(model_path)

# 假设我们要修改的输出是模型的最后一个输出
# 首先,获取输出的数量
num_outputs = len(model.graph.output)
print(num_outputs)

for output in model.graph.output:

if output.type.tensor_type.elem_type == onnx.TensorProto.INT64:
# 修改数据类型为 float32
output.type.tensor_type.elem_type = onnx.TensorProto.FLOAT

# 保存修改后的模型
modified_model_path = '/mnt/share_disk/bruce_cui/onnx_models/laneline_154w_20240320_fastbev_onnxsim_output_float32.onnx'
onnx.save(model, modified_model_path)

print_colored_box("模型输出类型已从 int64 修改为 float32。")

5 changes: 4 additions & 1 deletion Tools/onnx_model_convert_onnxSimplify.py
Expand Up @@ -56,7 +56,10 @@ def print_onnx_input_output(model_path):

if __name__=="__main__":
# model_path = "/Users/bruce/Downloads/8620_deploy/swin_tiny_patch4_window7_224_224_elementwise_affine.onnx"
model_path = "/Users/bruce/Downloads/obstacle_v1.0.onnx"
# model_path = "/Users/bruce/Downloads/obstacle_v1.0.onnx"
# model_path = "/mnt/share_disk/bruce_cui/onnx_models/laneline_154w_20240320_fastbev.onnx"
model_path = "/mnt/share_disk/bruce_cui/onnx_models/laneline_20240330_fastbev_wo_argmax.onnx"
# model_path = "/Users/bruce/Downloads/8620_deploy/Laneline/models/epoch_latest_0302.onnx"

onnx_simplify(model_path)
# print_onnx_input_output(model_path)
Expand Down

0 comments on commit dcd0c78

Please sign in to comment.