diff --git a/converter/main.py b/converter/main.py index 32c36292af1dedb6f4f598d05f27e526b7bf3967..148900b6278b514213b1ba50c577f5d9a8da64ef 100644 --- a/converter/main.py +++ b/converter/main.py @@ -362,7 +362,10 @@ def extract_additional_data(name, to_transpose, onnx_graph, verbose): return extract_additional_data_from_node(init, to_transpose) for node in onnx_graph.node: # not found in initializaer, search in Constant if name == node.output[0]: - return extract_additional_data_from_node(node.attribute[0].t, to_transpose) + if node.op_type == "Identity": + return extract_additional_data(node.input[0], to_transpose, onnx_graph, verbose) + else: + return extract_additional_data_from_node(node.attribute[0].t, to_transpose) quit("[ERROR] unable to extract data in {}".format(name)) @@ -414,7 +417,10 @@ def getNodesWithOutputNotConst(name, model): for node in model.graph.node: for out in node.output: if out == name: - return node + if node.op_type == "Identity": + return getNodesWithOutputNotConst(node.input[0], model) + else: + return node for node in model.graph.input: if node.name == name: return node @@ -499,6 +505,10 @@ def add_transpose_to_output(node, myGraph, map_onnx_to_myGraph): def parse_graph_node( node, model_onnx, myGraph, node_annotation, map_onnx_to_myGraph, verbose ): + if node.name not in node_annotation: + print(f"[WARN] node {node.name} never reached: removed") + return + if verbose > 1: print(f"parse node {node.name} remove={node_annotation[node.name].to_remove}") @@ -1822,6 +1832,8 @@ def annotate_node( for inp in node.input: n2 = getNodesWithOutputNotConst(inp, model_onnx) if n2 is not None: + if verbose: + print(f"[INFO] input: {n2.name}") if n2.name in node_annotation: if data_layout is None: data_layout = node_annotation[n2.name].layout_onnx @@ -1836,6 +1848,8 @@ def annotate_node( # print("[INFO] annotations:\n{" + "\n".join("{!r}: {!r},".format(k, v) for k, v in node_annotation.items())+ "}") # quit() else: # not ready yet + if verbose > 2: + print(f"[INFO] input not ready: {n2.name}, abord") return if verbose > 1 and data_layout is None: @@ -2037,7 +2051,7 @@ def annotate_graph(model_onnx, node_annotation, data_layout, verbose): for inp in model_onnx.graph.node: if inp.op_type == "Constant": node_annotation[inp.name] = Node_Annotation() - node_annotation[inp.name].layout_onnx = None + node_annotation[inp.name].layout_onnx = None for inp in model_onnx.graph.input: nexts = getNodesWithInput(inp.name, model_onnx)