| | |
| | |
| |
|
| | |
| | import argparse |
| | import logging |
| |
|
| | import numpy as np |
| | import onnx |
| | import sympy |
| | from onnx import helper, numpy_helper, shape_inference |
| | from packaging import version |
| |
|
# Fail at import time when the installed onnx package is older than 1.8.0.
assert version.parse(onnx.__version__) >= version.parse("1.8.0")

# Module-level logger used by the shape inference code below.
logger = logging.getLogger(__name__)
| |
|
| |
|
def get_attribute(node, attr_name, default_value=None):
    """Look up an attribute of *node* by name.

    Returns the value decoded by ``helper.get_attribute_value`` for the first
    attribute whose ``name`` equals *attr_name*, or *default_value* when the
    node has no such attribute.
    """
    for candidate in node.attribute:
        if candidate.name == attr_name:
            return helper.get_attribute_value(candidate)
    return default_value
| |
|
| |
|
def get_dim_from_proto(dim):
    """Return the populated "value" field of a dimension proto.

    ``dim.WhichOneof("value")`` names the field that is set; its content is
    returned. When no field is set (the name is not a ``str``) the result is
    ``None``.
    """
    field_name = dim.WhichOneof("value")
    if type(field_name) is str:
        return getattr(dim, field_name)
    return None
| |
|
| |
|
def is_sequence(type_proto):
    """Tell whether *type_proto* holds a sequence type rather than a tensor type.

    Asserts that the populated "value" field is either ``tensor_type`` or
    ``sequence_type``; any other kind is rejected.
    """
    kind = type_proto.WhichOneof("value")
    assert kind in ("tensor_type", "sequence_type")
    return kind == "sequence_type"
| |
|
| |
|
def get_shape_from_type_proto(type_proto):
    """Return the shape of a tensor type proto as a list of dims.

    Each entry is produced by ``get_dim_from_proto``. Returns ``None`` when
    the tensor type has no ``shape`` field. Sequence types are not accepted
    (asserted).
    """
    assert not is_sequence(type_proto)
    tensor_type = type_proto.tensor_type
    if not tensor_type.HasField("shape"):
        return None
    return [get_dim_from_proto(dim) for dim in tensor_type.shape.dim]
| |
|
| |
|
def get_elem_type_from_type_proto(type_proto):
    """Return the tensor element type of *type_proto*.

    For a sequence type the element type of the sequence's tensor elements is
    returned; otherwise the tensor's own element type.
    """
    if not is_sequence(type_proto):
        return type_proto.tensor_type.elem_type
    return type_proto.sequence_type.elem_type.tensor_type.elem_type
| |
|
| |
|
def get_shape_from_value_info(vi):
    """Return the shape recorded in a value-info, or ``None`` if unavailable.

    ``None`` is returned when the value-info carries no type, or when it is a
    sequence whose elements are not tensors. For a sequence of tensors the
    element tensor's shape is returned.
    """
    if vi.type.WhichOneof("value") is None:
        return None
    if not is_sequence(vi.type):
        return get_shape_from_type_proto(vi.type)
    if vi.type.sequence_type.elem_type.WhichOneof("value") != "tensor_type":
        return None
    return get_shape_from_type_proto(vi.type.sequence_type.elem_type)
| |
|
| |
|
def make_named_value_info(name):
    """Create an ``onnx.ValueInfoProto`` that has only its name set."""
    value_info = onnx.ValueInfoProto()
    value_info.name = name
    return value_info
| |
|
| |
|
def get_shape_from_sympy_shape(sympy_shape):
    """Convert a list of sympy dims into plain ints, strings and ``None``.

    ``None`` entries are kept, literal dims (per ``is_literal``) become
    ``int`` and every other dim is rendered with ``str``.
    """

    def to_plain(dim):
        if dim is None:
            return None
        return int(dim) if is_literal(dim) else str(dim)

    return [to_plain(dim) for dim in sympy_shape]
| |
|
| |
|
def is_literal(dim):
    """Tell whether *dim* is a concrete number rather than a symbolic dim.

    True for plain ``int``, ``np.int64``, ``np.int32`` and ``sympy.Integer``
    instances (exact type match). For any other object the value of its
    ``is_number`` attribute is returned when it has one, else ``False``.
    """
    if type(dim) in (int, np.int64, np.int32, sympy.Integer):
        return True
    return hasattr(dim, "is_number") and dim.is_number
| |
|
| |
|
def handle_negative_axis(axis, rank):
    """Normalize *axis* to a non-negative index for a tensor of rank *rank*.

    Asserts ``-rank <= axis < rank``; negative axes count from the end.
    """
    assert -rank <= axis < rank
    if axis >= 0:
        return axis
    return rank + axis
| |
|
| |
|
def get_opset(mp, domain=None):
    """Return the opset version that *mp* imports for *domain*.

    *domain* may be a single domain or a list of domains; when it is falsy
    the default ONNX domains ``""``, ``"onnx"`` and ``"ai.onnx"`` are used.
    The version of the first matching entry in ``mp.opset_import`` is
    returned, or ``None`` when nothing matches.
    """
    wanted = domain or ["", "onnx", "ai.onnx"]
    if type(wanted) is not list:
        wanted = [wanted]
    return next((entry.version for entry in mp.opset_import if entry.domain in wanted), None)
| |
|
| |
|
def as_scalar(x):
    """Unwrap a one-element list or a numpy array into a scalar.

    A ``list`` must hold exactly one element (asserted) and that element is
    returned; an ``np.ndarray`` is converted with ``item()``; anything else
    is returned unchanged.
    """
    kind = type(x)
    if kind is list:
        assert len(x) == 1
        return x[0]
    if kind is np.ndarray:
        return x.item()
    return x
| |
|
| |
|
def as_list(x, keep_none):
    """Coerce *x* into a list.

    Lists are returned as-is, numpy arrays are converted with ``list()``, and
    any other value is wrapped in a one-element list. When *keep_none* is
    truthy a ``None`` input stays ``None`` instead of becoming ``[None]``.
    """
    kind = type(x)
    if kind is list:
        return x
    if kind is np.ndarray:
        return list(x)
    if keep_none and x is None:
        return None
    return [x]
| |
|
| |
|
def sympy_reduce_product(x):
    """Multiply the entries of a list together, starting from ``sympy.Integer(1)``.

    A value that is not a ``list`` is returned unchanged.
    """
    if type(x) is not list:
        return x
    product = sympy.Integer(1)
    for factor in x:
        product = product * factor
    return product
| |
|
| |
|
| | class SymbolicShapeInference: |
def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
    """Set up the per-op dispatch tables and the inference state.

    Args:
        int_max: stored as ``int_max_``; the Min/Max handlers treat literals
            beyond this magnitude as "unbounded" sentinels.
        auto_merge: stored as ``auto_merge_``; enables automatic merging of
            conflicting symbolic dims.
        guess_output_rank: stored as ``guess_output_rank_`` (its use is not
            visible in this part of the file).
        verbose: stored as ``verbose_``; compared against integer levels to
            decide what gets logged.
        prefix: stored as ``prefix_``; embedded in generated symbolic dim
            names and extended for nested subgraph inference.
    """
    # Maps an ONNX op_type to the method that infers its output shape.
    self.dispatcher_ = {
        "Add": self._infer_symbolic_compute_ops,
        "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
        "AveragePool": self._infer_Pool,
        "BatchNormalization": self._infer_BatchNormalization,
        "Cast": self._infer_Cast,
        "CategoryMapper": self._infer_CategoryMapper,
        "Compress": self._infer_Compress,
        "Concat": self._infer_Concat,
        "ConcatFromSequence": self._infer_ConcatFromSequence,
        "Constant": self._infer_Constant,
        "ConstantOfShape": self._infer_ConstantOfShape,
        "Conv": self._infer_Conv,
        "CumSum": self._pass_on_shape_and_type,
        "Div": self._infer_symbolic_compute_ops,
        "Einsum": self._infer_Einsum,
        "Expand": self._infer_Expand,
        "Equal": self._infer_symbolic_compute_ops,
        "Floor": self._infer_symbolic_compute_ops,
        "Gather": self._infer_Gather,
        "GatherElements": self._infer_GatherElements,
        "GatherND": self._infer_GatherND,
        "Identity": self._pass_on_shape_and_type,
        "AllReduce": self._pass_on_shape_and_type,
        "If": self._infer_If,
        "Loop": self._infer_Loop,
        "MatMul": self._infer_MatMul,
        "MatMulInteger16": self._infer_MatMulInteger,
        "MaxPool": self._infer_Pool,
        "Max": self._infer_symbolic_compute_ops,
        "MemcpyFromHost": self._pass_on_shape_and_type,
        "MemcpyToHost": self._pass_on_shape_and_type,
        "Min": self._infer_symbolic_compute_ops,
        "MoE": self._pass_on_shape_and_type,
        "Mul": self._infer_symbolic_compute_ops,
        "NonMaxSuppression": self._infer_NonMaxSuppression,
        "NonZero": self._infer_NonZero,
        "OneHot": self._infer_OneHot,
        "Pad": self._infer_Pad,
        "Range": self._infer_Range,
        "Reciprocal": self._pass_on_shape_and_type,
        "ReduceSum": self._infer_ReduceSum,
        "ReduceProd": self._infer_ReduceProd,
        "Reshape": self._infer_Reshape,
        "Resize": self._infer_Resize,
        "Round": self._pass_on_shape_and_type,
        "Scan": self._infer_Scan,
        "ScatterElements": self._infer_ScatterElements,
        "SequenceAt": self._infer_SequenceAt,
        "SequenceInsert": self._infer_SequenceInsert,
        "Shape": self._infer_Shape,
        "Size": self._infer_Size,
        "Slice": self._infer_Slice,
        "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
        "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
        "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
        "Split": self._infer_Split,
        "SplitToSequence": self._infer_SplitToSequence,
        "Squeeze": self._infer_Squeeze,
        "Sub": self._infer_symbolic_compute_ops,
        "Tile": self._infer_Tile,
        "TopK": self._infer_TopK,
        "Transpose": self._infer_Transpose,
        "Unsqueeze": self._infer_Unsqueeze,
        "Where": self._infer_symbolic_compute_ops,
        "ZipMap": self._infer_ZipMap,
        "Neg": self._infer_symbolic_compute_ops,
        # The entries below look like fused / contrib-style ops rather than
        # standard ONNX ops - NOTE(review): confirm the intended grouping.
        "Attention": self._infer_Attention,
        "BiasAdd": self._infer_BiasAdd,
        "BiasGelu": self._infer_BiasGelu,
        "BiasSplitGelu": self._infer_BiasSplitGelu,
        "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
        "DequantizeLinear": self._infer_DequantizeLinear,
        "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
        "FastGelu": self._infer_FastGelu,
        "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
        "Gelu": self._infer_Gelu,
        "GemmFastGelu": self._infer_GemmFastGelu,
        "GemmFloat8": self._infer_GemmFloat8,
        "GroupNorm": self._infer_GroupNorm,
        "GroupQueryAttention": self._infer_GroupQueryAttention,
        "SkipGroupNorm": self._infer_SkipGroupNorm,
        "LayerNormalization": self._infer_LayerNormalization,
        "LongformerAttention": self._infer_LongformerAttention,
        "MultiHeadAttention": self._infer_MultiHeadAttention,
        "NhwcConv": self._infer_NhwcConv,
        "PackedAttention": self._infer_PackedAttention,
        "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
        "PagedAttention": self._infer_PagedAttention,
        "PythonOp": self._infer_PythonOp,
        "QuantizeLinear": self._infer_QuantizeLinear,
        "QuickGelu": self._infer_FastGelu,
        "RelativePositionBias": self._infer_RelativePositionBias,
        "RemovePadding": self._infer_RemovePadding,
        "RestorePadding": self._infer_RestorePadding,
        "RotaryEmbedding": self._infer_RotaryEmbedding,
        "SimplifiedLayerNormalization": self._infer_LayerNormalization,
        "SkipLayerNormalization": self._infer_SkipLayerNormalization,
        "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
    }
    # Second dispatch table keyed by ATen operator name
    # (presumably consulted for ATen nodes - the lookup is outside this view).
    self.aten_op_dispatcher_ = {
        "embedding": self._infer_Gather,
        "bitwise_or": self._infer_aten_bitwise_or,
        "diagonal": self._infer_aten_diagonal,
        "max_pool2d_with_indices": self._infer_aten_pool2d,
        "max": self._infer_aten_minmax,
        "min": self._infer_aten_minmax,
        "multinomial": self._infer_aten_multinomial,
        "unfold": self._infer_aten_unfold,
        "argmax": self._infer_aten_argmax,
        "avg_pool2d": self._infer_aten_pool2d,
        "_adaptive_avg_pool2d": self._infer_aten_pool2d,
        "numpy_T": self._infer_Transpose,
        "native_group_norm": self._infer_aten_group_norm,
        "upsample_nearest1d": self._infer_aten_upsample,
        "upsample_nearest2d": self._infer_aten_upsample,
        "upsample_nearest3d": self._infer_aten_upsample,
        "upsample_bicubic2d": self._infer_aten_upsample,
    }
    # Loop flag: _onnx_infer_subgraph keeps calling _infer_impl while it is True.
    self.run_ = True
    # symbol (str) -> symbol or int it should be replaced with
    self.suggested_merge_ = {}
    # symbol name (str) -> sympy object representing it
    self.symbolic_dims_ = {}
    # symbols that originate from graph inputs; preferred as merge targets
    self.input_symbols_ = {}
    self.auto_merge_ = auto_merge
    self.guess_output_rank_ = guess_output_rank
    self.verbose_ = verbose
    self.int_max_ = int_max
    # counter used to give each nested subgraph inference a unique prefix
    self.subgraph_id_ = 0
    self.prefix_ = prefix
| |
|
def _add_suggested_merge(self, symbols, apply=False):
    """Record that all of *symbols* denote the same dimension.

    One of them is chosen as the representative and every other one is
    mapped to it in ``self.suggested_merge_``.

    Args:
        symbols: iterable of known symbolic dim names (str) and/or literals.
        apply: when True and auto-merge is enabled, the merge is applied to
            the output model right away via ``_apply_suggested_merge``.
    """
    assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols])
    symbols = set(symbols)
    # Replace symbols that already have a merge target by that target.
    for k, v in self.suggested_merge_.items():
        if k in symbols:
            symbols.remove(k)
            symbols.add(v)
    map_to = None
    # Preference 1: a literal value.
    for s in symbols:
        if is_literal(s):
            map_to = s
            break
    # Preference 2: a symbol that comes from the graph inputs.
    if map_to is None:
        for s in symbols:
            if s in self.input_symbols_:
                map_to = s
                break
    # Preference 3: a symbol that is a plain sympy.Symbol (not an expression).
    if map_to is None:
        for s in symbols:
            if type(self.symbolic_dims_[s]) == sympy.Symbol:
                map_to = s
                break
    # Last resort: only expressions are left, pick the shortest name.
    if map_to is None:
        if self.verbose_ > 0:
            logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols)))
        symbols_list = list(symbols)
        lens = [len(s) for s in symbols_list]
        map_to = symbols_list[lens.index(min(lens))]
        symbols.remove(map_to)

    for s in symbols:
        if s == map_to:
            continue
        # Two different literals can never be merged.
        if is_literal(map_to) and is_literal(s):
            assert int(map_to) == int(s)
        self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to
        # Re-point existing merges that targeted s to the new representative.
        for k, v in self.suggested_merge_.items():
            if v == s:
                self.suggested_merge_[k] = map_to
    if apply and self.auto_merge_:
        self._apply_suggested_merge()
| |
|
| | def _apply_suggested_merge(self, graph_input_only=False): |
| | if not self.suggested_merge_: |
| | return |
| | for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)): |
| | for d in i.type.tensor_type.shape.dim: |
| | if d.dim_param in self.suggested_merge_: |
| | v = self.suggested_merge_[d.dim_param] |
| | if is_literal(v): |
| | d.dim_value = int(v) |
| | else: |
| | d.dim_param = v |
| |
|
def _preprocess(self, in_mp):
    """Copy *in_mp* into ``self.out_mp_`` and index its inputs and initializers.

    Populates ``graph_inputs_`` and ``initializers_`` (name -> proto) and
    seeds ``known_vi_`` with the graph inputs plus a tensor value-info
    synthesized for every initializer.
    """
    self.out_mp_ = onnx.ModelProto()
    self.out_mp_.CopyFrom(in_mp)
    graph = self.out_mp_.graph
    self.graph_inputs_ = {graph_input.name: graph_input for graph_input in list(graph.input)}
    self.initializers_ = {init.name: init for init in graph.initializer}
    self.known_vi_ = {graph_input.name: graph_input for graph_input in list(graph.input)}
    initializer_vi = {}
    for init in graph.initializer:
        initializer_vi[init.name] = helper.make_tensor_value_info(init.name, init.data_type, list(init.dims))
    self.known_vi_.update(initializer_vi)
| |
|
| | def _merge_symbols(self, dims): |
| | if not all([type(d) == str for d in dims]): |
| | if self.auto_merge_: |
| | unique_dims = list(set(dims)) |
| | is_int = [is_literal(d) for d in unique_dims] |
| | assert sum(is_int) <= 1 |
| | if sum(is_int) == 1: |
| | int_dim = is_int.index(1) |
| | if self.verbose_ > 0: |
| | logger.debug( |
| | "dim {} has been merged with value {}".format( |
| | unique_dims[:int_dim] + unique_dims[int_dim + 1 :], |
| | unique_dims[int_dim], |
| | ) |
| | ) |
| | self._check_merged_dims(unique_dims, allow_broadcast=False) |
| | return unique_dims[int_dim] |
| | else: |
| | if self.verbose_ > 0: |
| | logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}") |
| | return dims[0] |
| | else: |
| | return None |
| | if all([d == dims[0] for d in dims]): |
| | return dims[0] |
| | merged = [self.suggested_merge_.get(d, d) for d in dims] |
| | if all([d == merged[0] for d in merged]): |
| | assert merged[0] in self.symbolic_dims_ |
| | return merged[0] |
| | else: |
| | return None |
| |
|
| | |
| | def _broadcast_shapes(self, shape1, shape2): |
| | new_shape = [] |
| | rank1 = len(shape1) |
| | rank2 = len(shape2) |
| | new_rank = max(rank1, rank2) |
| | for i in range(new_rank): |
| | dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1 |
| | dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1 |
| | if dim1 == 1 or dim1 == dim2: |
| | new_dim = dim2 |
| | elif dim2 == 1: |
| | new_dim = dim1 |
| | else: |
| | new_dim = self._merge_symbols([dim1, dim2]) |
| | if not new_dim: |
| | |
| | |
| | |
| | if self.auto_merge_: |
| | self._add_suggested_merge([dim1, dim2], apply=True) |
| | else: |
| | logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) |
| | new_shape = [new_dim, *new_shape] |
| | return new_shape |
| |
|
| | def _get_shape(self, node, idx): |
| | name = node.input[idx] |
| | if name in self.known_vi_: |
| | vi = self.known_vi_[name] |
| | return get_shape_from_value_info(vi) |
| | else: |
| | assert name in self.initializers_ |
| | return list(self.initializers_[name].dims) |
| |
|
| | def _try_get_shape(self, node, idx): |
| | if idx > len(node.input) - 1: |
| | return None |
| | name = node.input[idx] |
| | if name in self.known_vi_: |
| | vi = self.known_vi_[name] |
| | return get_shape_from_value_info(vi) |
| | if name in self.initializers_: |
| | return list(self.initializers_[name].dims) |
| | return None |
| |
|
| | def _get_shape_rank(self, node, idx): |
| | return len(self._get_shape(node, idx)) |
| |
|
| | def _get_sympy_shape(self, node, idx): |
| | sympy_shape = [] |
| | for d in self._get_shape(node, idx): |
| | if type(d) == str: |
| | sympy_shape.append( |
| | self.symbolic_dims_[d] |
| | if d in self.symbolic_dims_ |
| | else sympy.Symbol(d, integer=True, nonnegative=True) |
| | ) |
| | else: |
| | assert None is not d |
| | sympy_shape.append(d) |
| | return sympy_shape |
| |
|
| | def _get_value(self, node, idx): |
| | name = node.input[idx] |
| | assert name in self.sympy_data_ or name in self.initializers_ |
| | return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name]) |
| |
|
| | def _try_get_value(self, node, idx): |
| | if idx >= len(node.input): |
| | return None |
| | name = node.input[idx] |
| | if name in self.sympy_data_ or name in self.initializers_: |
| | return self._get_value(node, idx) |
| | return None |
| |
|
def _update_computed_dims(self, new_sympy_shape):
    """Reconcile computed (expression) dims in *new_sympy_shape*, in place.

    For each dim that is neither a literal nor a string: if its string form
    has a symbolic merge target, the dim is replaced by that target's sympy
    object; if the target is a literal the dim is left alone; if there is no
    target the expression is registered in ``symbolic_dims_``.
    """
    for position, new_dim in enumerate(new_sympy_shape):
        if is_literal(new_dim) or type(new_dim) == str:
            continue
        str_dim = str(new_dim)
        if str_dim not in self.suggested_merge_:
            # remember a newly seen computed expression
            if str(new_dim) not in self.symbolic_dims_:
                self.symbolic_dims_[str(new_dim)] = new_dim
            continue
        target = self.suggested_merge_[str_dim]
        if is_literal(target):
            # nothing to create for a literal target
            continue
        new_sympy_shape[position] = self.symbolic_dims_[target]
| |
|
def _onnx_infer_single_node(self, node):
    """Create value-info entries for the outputs of *node*.

    Unless the op is in the skip list, ONNX's own ``infer_shapes`` is run on
    a temporary one-node graph and its results are copied; for skipped ops
    only the output name is recorded. Every non-empty output gets a new
    entry in the output model's value_info and in ``known_vi_``.
    """
    # Ops for which ONNX shape inference is not run here.
    skip_infer = node.op_type in [
        "If",
        "Loop",
        "Scan",
        "SplitToSequence",
        "ZipMap",
        "Attention",
        "BiasGelu",
        "EmbedLayerNormalization",
        "FastGelu",
        "Gelu",
        "GemmFastGelu",
        "LayerNormalization",
        "LongformerAttention",
        "DequantizeLinear",
        "QuantizeLinear",
        "RelativePositionBias",
        "RemovePadding",
        "RestorePadding",
        "SimplifiedLayerNormalization",
        "SkipLayerNormalization",
        "SkipSimplifiedLayerNormalization",
        "PackedAttention",
        "PagedAttention",
        "PythonOp",
        "MultiHeadAttention",
        "GroupNorm",
        "GroupQueryAttention",
        "SkipGroupNorm",
        "BiasSplitGelu",
        "BiasAdd",
        "NhwcConv",
        "QuickGelu",
        "RotaryEmbedding",
    ]

    if not skip_infer:
        # Initializers are handed to ONNX only for Unsqueeze with opset >= 9,
        # and only those that are not also graph inputs.
        # NOTE(review): get_opset may return None, which would make the
        # comparison below raise - confirm models always import the default domain.
        initializers = []
        if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]:
            initializers = [
                self.initializers_[name]
                for name in node.input
                if (name in self.initializers_ and name not in self.graph_inputs_)
            ]

        # Temporary graph holding just this node, with inputs typed from known_vi_.
        tmp_graph = helper.make_graph(
            [node],
            "tmp",
            [self.known_vi_[i] for i in node.input if i],
            [make_named_value_info(i) for i in node.output],
            initializers,
        )

        self.tmp_mp_.graph.CopyFrom(tmp_graph)

        self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)

    for i_o in range(len(node.output)):
        o = node.output[i_o]
        # an empty name marks an output that is not present
        if o:
            vi = self.out_mp_.graph.value_info.add()
            if not skip_infer:
                vi.CopyFrom(self.tmp_mp_.graph.output[i_o])
            else:
                vi.name = o
            self.known_vi_[o] = vi
| |
|
def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
    """Run symbolic shape inference on *subgraph* of *node* and write back the results.

    A nested ``SymbolicShapeInference`` is run over a temporary graph made of
    the subgraph plus every outer value it may reference implicitly. The
    subgraph's inputs (optionally), outputs, value_info and nodes are then
    replaced by the inferred ones, and symbolic dims that appear in the
    subgraph outputs are added to this instance's ``symbolic_dims_``.

    Args:
        node: the node owning the subgraph.
        subgraph: the GraphProto to infer; modified in place.
        use_node_input: when True, the subgraph inputs are replaced by the
            first ``len(node.input)`` inferred inputs.
        inc_subgraph_id: when True, ``subgraph_id_`` is incremented after it
            has been used for the nested prefix.

    Returns:
        The nested ``SymbolicShapeInference`` instance.
    """
    if self.verbose_ > 2:
        logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}")
    # Names declared by the subgraph itself; these shadow outer values of the
    # same name, so only the remaining known values become implicit inputs.
    subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)}
    subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs}
    tmp_graph = helper.make_graph(
        list(subgraph.node),
        "tmp",
        list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input],
        [make_named_value_info(i.name) for i in subgraph.output],
    )
    tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input])
    tmp_graph.initializer.extend(subgraph.initializer)
    self.tmp_mp_.graph.CopyFrom(tmp_graph)

    symbolic_shape_inference = SymbolicShapeInference(
        self.int_max_,
        self.auto_merge_,
        self.guess_output_rank_,
        self.verbose_,
        prefix=self.prefix_ + "_" + str(self.subgraph_id_),
    )
    if inc_subgraph_id:
        self.subgraph_id_ += 1

    symbolic_shape_inference._preprocess(self.tmp_mp_)
    symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
    # Repeat until the nested inference clears its run_ flag.
    while symbolic_shape_inference.run_:
        symbolic_shape_inference._infer_impl(self.sympy_data_.copy())
    symbolic_shape_inference._update_output_from_vi()
    if use_node_input:
        # keep only the inputs that correspond to the node's own inputs
        subgraph.ClearField("input")
        subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)])
    subgraph.ClearField("output")
    subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
    subgraph.ClearField("value_info")
    subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info)
    subgraph.ClearField("node")
    subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
    # Symbolic dims first seen in the subgraph outputs are made known here too.
    subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output]
    subgraph_new_symbolic_dims = {
        d for s in subgraph_shapes if s for d in s if type(d) == str and d not in self.symbolic_dims_
    }
    new_dims = {}
    for d in subgraph_new_symbolic_dims:
        assert d in symbolic_shape_inference.symbolic_dims_
        new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
    self.symbolic_dims_.update(new_dims)
    return symbolic_shape_inference
| |
|
| | def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False): |
| | def int_or_float(value, allow_float_values): |
| | |
| | if allow_float_values and value % 1 != 0: |
| | return value |
| | return int(value) |
| |
|
| | values = [self._try_get_value(node, i) for i in range(len(node.input))] |
| | if all([v is not None for v in values]): |
| | |
| | for i, v in enumerate(values): |
| | if type(v) != np.ndarray: |
| | continue |
| | if len(v.shape) > 1: |
| | new_v = None |
| | elif len(v.shape) == 0: |
| | new_v = int_or_float(v.item(), allow_float_values) |
| | else: |
| | assert len(v.shape) == 1 |
| | new_v = [int_or_float(vv, allow_float_values) for vv in v] |
| | values[i] = new_v |
| | values_len = [len(v) if isinstance(v, list) else 0 for v in values] |
| | max_len = max(values_len) |
| | if max_len >= 1 and broadcast: |
| | |
| | for i, v in enumerate(values): |
| | if v is None: |
| | continue |
| | if isinstance(v, list): |
| | if len(v) < max_len: |
| | values[i] = v * max_len |
| | else: |
| | assert len(v) == max_len |
| | else: |
| | values[i] = [v] * max_len |
| | return values |
| |
|
| | def _compute_on_sympy_data(self, node, op_func): |
| | assert len(node.output) == 1 |
| |
|
| | |
| | |
| | |
| | if node.op_type in ["Mul", "Div"]: |
| | values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True) |
| | else: |
| | values = self._get_int_or_float_values(node, broadcast=True) |
| |
|
| | if all([v is not None for v in values]): |
| | is_list = [isinstance(v, list) for v in values] |
| | as_list = any(is_list) |
| | if as_list: |
| | self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)] |
| | else: |
| | self.sympy_data_[node.output[0]] = op_func(values) |
| |
|
| | def _pass_on_sympy_data(self, node): |
| | assert len(node.input) == 1 or node.op_type in [ |
| | "Reshape", |
| | "Unsqueeze", |
| | "Squeeze", |
| | ] |
| | self._compute_on_sympy_data(node, lambda x: x[0]) |
| |
|
def _pass_on_shape_and_type(self, node):
    """Give the first output the element type and shape of the first input."""
    out_name = node.output[0]
    vi = self.known_vi_[out_name]
    elem_type = get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type)
    shape = self._get_shape(node, 0)
    vi.CopyFrom(helper.make_tensor_value_info(out_name, elem_type, shape))
| |
|
def _new_symbolic_dim(self, prefix, dim):
    """Return the symbolic dim named ``{prefix}_d{dim}``.

    If that name already has a merge target, the target is returned (as a
    ``sympy.Integer`` when it is a literal). Otherwise a new non-negative
    integer ``sympy.Symbol`` is created and registered in ``symbolic_dims_``.
    """
    new_dim = f"{prefix}_d{dim}"
    if new_dim not in self.suggested_merge_:
        new_symbol = sympy.Symbol(new_dim, integer=True, nonnegative=True)
        self.symbolic_dims_[new_dim] = new_symbol
        return new_symbol
    target = self.suggested_merge_[new_dim]
    if is_literal(target):
        return sympy.Integer(int(target))
    return target
| |
|
def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
    """Create a symbolic dim for dimension *dim* of output *out_idx* of *node*.

    The name is built from the op type, this instance's prefix, the node's
    position in the output model's graph and the output index.
    """
    node_position = list(self.out_mp_.graph.node).index(node)
    prefix = f"{node.op_type}{self.prefix_}_{node_position}_o{out_idx}_"
    return self._new_symbolic_dim(prefix, dim)
| |
|
def _new_symbolic_shape(self, rank, node, out_idx=0):
    """Return *rank* symbolic dims for output *out_idx* of *node*."""
    shape = []
    for axis in range(rank):
        shape.append(self._new_symbolic_dim_from_output(node, out_idx, axis))
    return shape
| |
|
def _compute_conv_pool_shape(self, node, channels_last=False):
    """Compute the symbolic output shape of a convolution or pooling node.

    Args:
        node: node whose first input is the data tensor. When a second input
            exists it is taken as the weight: its first dim becomes the
            output channel count and its spatial dims the kernel shape.
            Otherwise the ``kernel_shape`` attribute is used.
        channels_last: when True the channel axis is the last axis and the
            spatial axes sit directly before it; otherwise the channel axis
            is axis 1 and the spatial axes are the trailing axes.

    Returns:
        The output shape as a list of ints / sympy expressions.
    """
    sympy_shape = self._get_sympy_shape(node, 0)
    if len(node.input) > 1:
        W_shape = self._get_sympy_shape(node, 1)
        rank = len(W_shape) - 2  # number of spatial axes
        kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
        # Index the channel axis from the end for channels-last layouts so
        # that inputs with a rank other than 4 are handled as well.
        sympy_shape[-1 if channels_last else 1] = W_shape[0]
    else:
        kernel_shape = get_attribute(node, "kernel_shape")
        rank = len(kernel_shape)

    assert len(sympy_shape) == rank + 2

    # Slice selecting the spatial axes for the layout in use.
    spatial_axes = slice(-rank - 1, -1) if channels_last else slice(-rank, None)
    spatial_shape = sympy_shape[spatial_axes]
    is_symbolic_dims = [not is_literal(i) for i in spatial_shape]

    if not any(is_symbolic_dims):
        # All spatial dims are literal: reuse the shape already recorded for
        # the output, if there is one. The value-info may carry no type yet
        # (shape is None); in that case fall through and compute the shape.
        shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
        if shape:
            assert len(sympy_shape) == len(shape)
            sympy_shape[spatial_axes] = [sympy.Integer(d) for d in shape[spatial_axes]]
            return sympy_shape

    dilations = get_attribute(node, "dilations", [1] * rank)
    strides = get_attribute(node, "strides", [1] * rank)
    effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
    pads = get_attribute(node, "pads")
    if pads is None:
        pads = [0] * (2 * rank)
        auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
        if auto_pad != "VALID" and auto_pad != "NOTSET":
            try:
                # The residual must be taken over the spatial dims; for a
                # channels-last layout the trailing dims include the channel.
                residual = [sympy.Mod(d, s) for d, s in zip(spatial_shape, strides)]
                total_pads = [
                    max(0, (k - s) if r == 0 else (k - r))
                    for k, s, r in zip(effective_kernel_shape, strides, residual)
                ]
            except TypeError:
                # Comparing symbolic expressions can raise TypeError;
                # fall back to assuming a zero residual.
                total_pads = [max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)]
        elif auto_pad == "VALID":
            total_pads = []
        else:
            total_pads = [0] * rank
    else:
        assert len(pads) == 2 * rank
        total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]

    ceil_mode = get_attribute(node, "ceil_mode", 0)
    for i in range(rank):
        axis = -rank + i + (-1 if channels_last else 0)
        effective_input_size = sympy_shape[axis]
        if len(total_pads) > 0:
            effective_input_size = effective_input_size + total_pads[i]
        if ceil_mode:
            strided_kernel_positions = sympy.ceiling(
                (effective_input_size - effective_kernel_shape[i]) / strides[i]
            )
        else:
            strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i]
        sympy_shape[axis] = strided_kernel_positions + 1
    return sympy_shape
| |
|
| | def _check_merged_dims(self, dims, allow_broadcast=True): |
| | if allow_broadcast: |
| | dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] |
| | if not all([d == dims[0] for d in dims]): |
| | self._add_suggested_merge(dims, apply=True) |
| |
|
def _compute_matmul_shape(self, node, output_dtype=None):
    """Infer the output value-info of a matrix multiplication node.

    Rank-1 operands contract over their only axis; otherwise the leading
    dims are broadcast and the last two dims follow matrix-product rules.
    The two contracted dims are checked for equality, which may suggest a
    merge. *output_dtype* defaults to the element type of the first input.
    """
    lhs_shape = self._get_shape(node, 0)
    rhs_shape = self._get_shape(node, 1)
    lhs_rank, rhs_rank = len(lhs_shape), len(rhs_shape)
    assert lhs_rank > 0 and rhs_rank > 0

    # A vector is reduced over axis 0; a matrix over its last (lhs) or
    # second to last (rhs) axis.
    lhs_reduce_dim = 0 if lhs_rank == 1 else -1
    rhs_reduce_dim = 0 if rhs_rank == 1 else -2

    if lhs_rank == 1 and rhs_rank == 1:
        new_shape = []
    elif lhs_rank == 1:
        new_shape = rhs_shape[:rhs_reduce_dim] + [rhs_shape[-1]]
    elif rhs_rank == 1:
        new_shape = lhs_shape[:lhs_reduce_dim]
    else:
        batch_dims = self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2])
        new_shape = [*batch_dims, lhs_shape[-2], rhs_shape[-1]]

    # the contracted dims of both operands have to agree
    self._check_merged_dims(
        [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
        allow_broadcast=False,
    )
    if output_dtype is None:
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_name = node.output[0]
    vi = self.known_vi_[out_name]
    vi.CopyFrom(helper.make_tensor_value_info(out_name, output_dtype, new_shape))
| |
|
def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
    """
    Make dst_type compatible with src_type, in place.

    The element types must match, otherwise ValueError is raised. When the
    destination has no shape, the source tensor type is copied over. When it
    has one, every dim that differs from the source dim is replaced: by a
    new symbolic dim for tensor types, by an empty dim for sequence types.
    """

    def tensor_type_of(type_proto):
        if is_sequence(type_proto):
            return type_proto.sequence_type.elem_type.tensor_type
        return type_proto.tensor_type

    dst_tensor_type = tensor_type_of(dst_type)
    src_tensor_type = tensor_type_of(src_type)
    if dst_tensor_type.elem_type != src_tensor_type.elem_type:
        node_id = node.name if node.name else node.op_type
        raise ValueError(
            f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
        )
    if not dst_tensor_type.HasField("shape"):
        dst_tensor_type.CopyFrom(src_tensor_type)
        return
    dim_pairs = zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)
    for di, (dst_dim, src_dim) in enumerate(dim_pairs):
        if dst_dim != src_dim:
            replacement = onnx.TensorShapeProto.Dimension()
            if not is_sequence(dst_type):
                replacement.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di))
            dst_tensor_type.shape.dim[di].CopyFrom(replacement)
| |
|
def _infer_ArrayFeatureExtractor(self, node):
    """Infer the output of ArrayFeatureExtractor.

    The output shape is the data shape without its last dim, followed by the
    indices shape; the element type is that of the data input.
    """
    data_shape = self._get_shape(node, 0)
    indices_shape = self._get_shape(node, 1)
    out_name = node.output[0]
    vi = self.known_vi_[out_name]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_shape = data_shape[:-1] + indices_shape
    vi.CopyFrom(helper.make_tensor_value_info(out_name, elem_type, out_shape))
| |
|
| | def _infer_symbolic_compute_ops(self, node): |
| | funcs = { |
| | "Add": lambda l: l[0] + l[1], |
| | "Div": lambda l: ( |
| | int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1] |
| | ), |
| | "Equal": lambda l: l[0] == l[1], |
| | "Floor": lambda l: sympy.floor(l[0]), |
| | "Max": lambda l: ( |
| | l[1] |
| | if is_literal(l[0]) and int(l[0]) < -self.int_max_ |
| | else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])) |
| | ), |
| | "Min": lambda l: ( |
| | l[1] |
| | if is_literal(l[0]) and int(l[0]) > self.int_max_ |
| | else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])) |
| | ), |
| | "Mul": lambda l: int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1], |
| | "Sub": lambda l: l[0] - l[1], |
| | "Where": lambda l: l[1] if l[0] else l[2], |
| | "Neg": lambda l: -l[0], |
| | } |
| | assert node.op_type in funcs |
| | self._compute_on_sympy_data(node, funcs[node.op_type]) |
| |
|
| | def _infer_Cast(self, node): |
| | self._pass_on_sympy_data(node) |
| |
|
def _infer_CategoryMapper(self, node):
    """Infer the output of CategoryMapper.

    The shape is that of the input; the element type is INT64 for a STRING
    input and STRING for any other input type.
    """
    input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    is_string_input = input_type == onnx.TensorProto.STRING
    output_type = onnx.TensorProto.INT64 if is_string_input else onnx.TensorProto.STRING
    out_name = node.output[0]
    vi = self.known_vi_[out_name]
    vi.CopyFrom(helper.make_tensor_value_info(out_name, output_type, self._get_shape(node, 0)))
| |
|
def _infer_Compress(self, node):
    """Infer the output of Compress.

    The number of selected elements is unknown, so a new symbolic dim stands
    for it. Without an ``axis`` attribute the output is 1-D; with one, that
    axis of the input shape is replaced by the new dim.
    """
    input_shape = self._get_shape(node, 0)
    compress_len = str(self._new_symbolic_dim_from_output(node))
    axis = get_attribute(node, "axis")
    if axis is not None:
        output_shape = input_shape
        output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
    else:
        output_shape = [compress_len]
    out_name = node.output[0]
    vi = self.known_vi_[out_name]
    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi.CopyFrom(helper.make_tensor_value_info(out_name, elem_type, output_shape))
| |
|
def _infer_Concat(self, node):
    """Infer the output value and shape of Concat.

    If any input has a known value and all values can be obtained, the
    concatenated value is stored in ``sympy_data_`` (only axis 0 is
    supported for that, asserted). The output shape sums the dims along the
    concat axis and merges the dims of all inputs on every other axis.
    """
    if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]):
        values = self._get_int_or_float_values(node)
        if all([v is not None for v in values]):
            assert get_attribute(node, "axis") == 0
            self.sympy_data_[node.output[0]] = []
            for i in range(len(node.input)):
                value = values[i]
                # lists are spliced in, scalars are appended
                if isinstance(value, list):
                    self.sympy_data_[node.output[0]].extend(value)
                else:
                    self.sympy_data_[node.output[0]].append(value)

    sympy_shape = self._get_sympy_shape(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
    # accumulate the size along the concat axis; inputs with an empty shape are skipped
    for i_idx in range(1, len(node.input)):
        input_shape = self._get_sympy_shape(node, i_idx)
        if input_shape:
            sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
    self._update_computed_dims(sympy_shape)
    # on all other axes the inputs have to agree: merge their dims
    for d in range(len(sympy_shape)):
        if d == axis:
            continue
        dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)]
        if all([d == dims[0] for d in dims]):
            continue
        merged = self._merge_symbols(dims)
        if type(merged) == str:
            sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
        else:
            sympy_shape[d] = merged
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(sympy_shape),
        )
    )
| |
|
def _infer_ConcatFromSequence(self, node):
    """Set the output shape of a ConcatFromSequence node.

    A new symbolic dim is created for the concatenated axis. When the
    ``new_axis`` attribute is truthy that dim is inserted at ``axis``;
    otherwise it replaces the existing dim at ``axis``. The element type is
    the tensor element type of the input sequence.
    """
    elem_shape = self._get_shape(node, 0)
    extra_rank = 1 if get_attribute(node, "new_axis") else 0
    axis = handle_negative_axis(get_attribute(node, "axis"), len(elem_shape) + extra_rank)
    concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
    if extra_rank:
        out_shape = elem_shape[:axis] + [concat_dim] + elem_shape[axis:]
    else:
        out_shape = elem_shape
        out_shape[axis] = concat_dim
    seq_elem_type = self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type
    out_vi = self.known_vi_[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], seq_elem_type, out_shape))
| |
|
def _infer_Constant(self, node):
    """Record the value of a Constant node in ``sympy_data_``.

    The ``value`` attribute is converted with ``numpy_helper.to_array`` and
    stored under the node's first output name.
    """
    self.sympy_data_[node.output[0]] = numpy_helper.to_array(get_attribute(node, "value"))
| |
|
def _infer_ConstantOfShape(self, node):
    """Infer the output shape (and, when possible, the value) of ConstantOfShape.

    If the value of input 0 is known it is used as the output shape. When
    the output element type is INT64 and every dim is a literal, the filled
    tensor is also stored in ``sympy_data_`` for value propagation.

    If the value of input 0 is unknown, a new symbolic shape is created whose
    rank is the first dim of input 0's shape.

    A missing ``value`` attribute is treated as a fill value of 0 instead of
    being handed to ``numpy_helper.to_array``, which cannot convert a plain
    ``int``.
    """
    out_name = node.output[0]
    sympy_shape = self._get_int_or_float_values(node)[0]
    vi = self.known_vi_[out_name]
    if sympy_shape is not None:
        if not isinstance(sympy_shape, list):
            sympy_shape = [sympy_shape]
        self._update_computed_dims(sympy_shape)
        is_int64 = vi.type.tensor_type.elem_type == onnx.TensorProto.INT64
        if is_int64 and all(is_literal(x) for x in sympy_shape):
            value_attr = get_attribute(node, "value")
            fill = numpy_helper.to_array(value_attr) if value_attr is not None else 0
            self.sympy_data_[out_name] = np.ones([int(x) for x in sympy_shape], dtype=np.int64) * fill
    else:
        sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node)

    vi.CopyFrom(
        helper.make_tensor_value_info(
            out_name,
            vi.type.tensor_type.elem_type,
            get_shape_from_sympy_shape(sympy_shape),
        )
    )
| |
|
def _infer_Conv(self, node):
    """Set the output shape of a Conv node.

    The shape comes from ``_compute_conv_pool_shape``; the element type
    already recorded on the output value info is kept.
    """
    out_sympy_shape = self._compute_conv_pool_shape(node)
    self._update_computed_dims(out_sympy_shape)
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    kept_elem_type = out_vi.type.tensor_type.elem_type
    out_vi.CopyFrom(
        helper.make_tensor_value_info(out_name, kept_elem_type, get_shape_from_sympy_shape(out_sympy_shape))
    )
| |
|
def _infer_NhwcConv(self, node):
    """Set the output shape of an NhwcConv node.

    Same as the regular convolution path but ``_compute_conv_pool_shape`` is
    called with ``channels_last=True``, and the element type is copied from
    input 0.
    """
    out_sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
    self._update_computed_dims(out_sympy_shape)
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(
        helper.make_tensor_value_info(out_name, in_elem_type, get_shape_from_sympy_shape(out_sympy_shape))
    )
| |
|
def _infer_DequantizeLinear(self, node):
    """Set the output type and shape of a DequantizeLinear node.

    The output element type is taken from input 1 and the output shape from
    input 0.
    """
    out_name = node.output[0]
    elem_type_from_input1 = self.known_vi_[node.input[1]].type.tensor_type.elem_type
    shape_from_input0 = self._get_shape(node, 0)
    self.known_vi_[out_name].CopyFrom(
        helper.make_tensor_value_info(out_name, elem_type_from_input1, shape_from_input0)
    )
| |
|
def _infer_QuantizeLinear(self, node):
    """Set the output type and shape of a QuantizeLinear node.

    The output element type defaults to UINT8 and is overridden by the
    element type of input 2 when that input is present and non-empty. The
    output shape is the shape of input 0.
    """
    has_third_input = len(node.input) > 2 and node.input[2]
    if has_third_input:
        quant_elem_type = self.known_vi_[node.input[2]].type.tensor_type.elem_type
    else:
        quant_elem_type = onnx.TensorProto.UINT8

    shape_from_input0 = self._get_shape(node, 0)

    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, quant_elem_type, shape_from_input0))
| |
|
def _infer_Einsum(self, node):
    """Infer the output shape of an Einsum node from its ``equation``.

    The equation attribute is handled as ``bytes``, so iterating a term
    yields integer byte values; these are used as subscript keys.

    Input side: every subscript letter is bound to the dim it labels. An
    ellipsis may sit anywhere in a term: letters before it are matched from
    the left of the operand shape, letters after it from the right, and the
    dims in between are the ellipsis dims. If several operands carry an
    ellipsis, their ellipsis dims are combined with ``_broadcast_shapes``.

    Output side:
      * explicit (``->`` present): the output follows the right-hand side,
        with the ellipsis dims placed where the ellipsis appears;
      * implicit: ellipsis dims first, then every letter that occurs exactly
        once on the left, in sorted order (the ONNX spec asks for
        alphabetical order).

    Compared with the previous implementation this no longer reads ellipsis
    dims from whichever operand happened to be last, no longer raises
    IndexError when an ellipsis covers more than three dims, binds letters
    that precede an ellipsis, and sorts the implicit output subscripts.

    The element type is copied from input 0.
    """
    dot = ord(".")
    comma = ord(",")

    equation = get_attribute(node, "equation").replace(b" ", b"")
    arrow_pos = equation.find(b"->")
    left_equation = equation[:arrow_pos] if arrow_pos != -1 else equation

    letter_to_dim = {}
    ellipsis_dims = []
    seen_ellipsis = False

    def bind(pairs):
        # Later bindings overwrite earlier ones unless the new dim is a sympy
        # Symbol, matching the precedence rule the code has always applied.
        for letter, dim in pairs:
            if letter == dot:
                continue
            if letter not in letter_to_dim or not isinstance(dim, sympy.Symbol):
                letter_to_dim[letter] = dim

    for operand_idx, term in enumerate(left_equation.split(b",")):
        shape = self._get_shape(node, operand_idx)
        rank = len(shape)
        dots_pos = term.find(b"...")
        if dots_pos == -1:
            # right-aligned pairing, last subscript first
            bind(zip(reversed(term), reversed(shape)))
            continue

        prefix = term[:dots_pos]
        suffix = term[dots_pos + 3 :]
        num_covered = max(rank - len(prefix) - len(suffix), 0)
        bind(zip(reversed(suffix), reversed(shape)))
        bind(zip(prefix, shape))
        covered = list(shape[len(prefix) : len(prefix) + num_covered])
        if not seen_ellipsis:
            ellipsis_dims = covered
            seen_ellipsis = True
        elif covered != ellipsis_dims:
            # NOTE(review): assumes _broadcast_shapes accepts lists of
            # different lengths, as NumPy-style broadcasting would - confirm.
            ellipsis_dims = self._broadcast_shapes(ellipsis_dims, covered)

    if arrow_pos != -1:
        right_equation = equation[arrow_pos + 2 :]
        out_dots_pos = right_equation.find(b"...")
        if out_dots_pos == -1:
            new_shape = [letter_to_dim[c] for c in right_equation if c != dot]
        else:
            leading = right_equation[:out_dots_pos]
            trailing = right_equation[out_dots_pos + 3 :]
            new_shape = (
                [letter_to_dim[c] for c in leading if c != dot]
                + list(ellipsis_dims)
                + [letter_to_dim[c] for c in trailing if c != dot]
            )
    else:
        occurrences = {}
        for c in left_equation:
            if c != comma and c != dot:
                occurrences[c] = occurrences.get(c, 0) + 1
        new_shape = list(ellipsis_dims)
        new_shape.extend(letter_to_dim[c] for c in sorted(occurrences) if occurrences[c] == 1)

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
| |
|
def _infer_Expand(self, node):
    """Set the output shape of an Expand node when the target shape is known.

    If the value of input 1 cannot be obtained nothing is written. Otherwise
    the output shape is the broadcast of input 0's shape with the target
    shape, and the element type is copied from input 0.
    """
    target_shape = as_list(self._try_get_value(node, 1), keep_none=True)
    if target_shape is None:
        return

    # dims of the target shape may be computed values
    self._update_computed_dims(target_shape)
    in_shape = self._get_shape(node, 0)
    out_shape = self._broadcast_shapes(in_shape, get_shape_from_sympy_shape(target_shape))
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, in_elem_type, out_shape))
| |
|
def _infer_Gather(self, node):
    """Infer the output shape of Gather and propagate known 1-D values.

    Shape: ``data_shape[:axis] + indices_shape + data_shape[axis + 1:]``,
    with the element type copied from input 0.

    Values: only attempted when input 0 has an entry in ``sympy_data_``,
    the data is 1-D and the ``axis`` attribute is 0. A list-valued entry is
    indexed by the known indices; a non-list entry is passed through after
    asserting the index is 0 or -1.
    """
    data_shape = self._get_shape(node, 0)
    axis_attr = get_attribute(node, "axis", 0)
    axis = handle_negative_axis(axis_attr, len(data_shape))
    indices_shape = self._get_shape(node, 1)
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            out_name,
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
        )
    )

    data_name = node.input[0]
    if not (data_name in self.sympy_data_ and len(data_shape) == 1 and axis_attr == 0):
        return
    idx = self._try_get_value(node, 1)
    if idx is None:
        return

    data = self.sympy_data_[data_name]
    if type(data) is list:
        if type(idx) is np.ndarray and len(idx.shape) == 1:
            self.sympy_data_[out_name] = [data[int(i)] for i in idx]
        else:
            self.sympy_data_[out_name] = data[int(idx)]
    else:
        assert idx == 0 or idx == -1
        self.sympy_data_[out_name] = data
| |
|
def _infer_GatherElements(self, node):
    """Set the output of GatherElements: shape of input 1, element type of input 0."""
    out_name = node.output[0]
    shape_of_indices = self._get_shape(node, 1)
    out_vi = self.known_vi_[out_name]
    data_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, data_elem_type, shape_of_indices))
| |
|
def _infer_GatherND(self, node):
    """Set the output shape of a GatherND node.

    With ``k`` the (literal) last dim of the indices shape and ``b`` the
    ``batch_dims`` attribute (default 0), the output shape is
    ``indices_shape[:-1] + data_shape[b + k:]``. The element type is copied
    from input 0.

    The ``batch_dims`` attribute was previously ignored, which gave a shape
    of the wrong rank whenever it was non-zero; with the default of 0 the
    result is unchanged.
    """
    data_shape = self._get_shape(node, 0)
    indices_shape = self._get_shape(node, 1)
    batch_dims = get_attribute(node, "batch_dims", 0)
    index_depth = indices_shape[-1]
    assert is_literal(index_depth) and batch_dims + index_depth <= len(data_shape)
    new_shape = indices_shape[:-1] + data_shape[batch_dims + index_depth :]
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            new_shape,
        )
    )
| |
|
def _infer_If(self, node):
    """Infer the outputs of an If node from its two branch subgraphs.

    When the condition value is known, the branch that would not run is
    overwritten with a copy of the one that would, so both subgraphs agree.
    Each branch is then inferred: outputs are first copied from the
    ``then`` branch and afterwards fused with the ``else`` branch through
    ``_fuse_tensor_type``. With a known condition, sympy data of the taken
    branch's outputs is forwarded to the node outputs.
    """
    branches = [
        get_attribute(node, "then_branch"),
        get_attribute(node, "else_branch"),
    ]
    cond = self._try_get_value(node, 0)
    taken = None  # index into `branches` of the branch selected by a known condition
    if cond is not None:
        taken = 0 if as_scalar(cond) > 0 else 1
        branches[1 - taken].CopyFrom(branches[taken])

    for branch_idx, branch in enumerate(branches):
        branch_infer = self._onnx_infer_subgraph(node, branch, use_node_input=False)
        for out_idx in range(len(node.output)):
            out_name = node.output[out_idx]
            vi = self.known_vi_[out_name]
            branch_out = branch.output[out_idx]
            if branch_idx == 0:
                vi.CopyFrom(branch_out)
                vi.name = out_name
            else:
                self._fuse_tensor_type(node, out_idx, vi.type, branch_out.type)

            if taken is not None and branch_idx == taken:
                if branch_out.name in branch_infer.sympy_data_:
                    self.sympy_data_[vi.name] = branch_infer.sympy_data_[branch_out.name]
| |
|
def _infer_Loop(self, node):
    """Infer the outputs of a Loop node by running inference on its body.

    Steps:
      1. Copy the value info of every node input onto the matching body
         input (keeping the body's own input names) and infer the body.
      2. Compare each loop-carried body output with its body input
         (body output ``k`` pairs with body input ``k + 1``). For sequences
         whose shape contains ``None``, the output element type is copied
         back to the input. For tensors, every dim that differs is replaced
         on both sides by a new symbolic dim. Either change triggers a
         second inference pass.
      3. Copy body outputs (skipping the first one) to the node outputs;
         outputs past the loop-carried ones get a new leading symbolic dim.
    """
    body = get_attribute(node, "body")
    assert len(body.input) == len(node.input)
    carried_count = len(node.input) - 2

    for idx, body_in in enumerate(body.input):
        kept_name = body_in.name
        body_in.CopyFrom(self.known_vi_[node.input[idx]])
        body_in.name = kept_name

    self._onnx_infer_subgraph(node, body)

    rerun = False
    for out_idx in range(1, carried_count + 1):
        body_out = body.output[out_idx]
        out_shape = get_shape_from_value_info(body_out)
        if is_sequence(body_out.type):
            if out_shape and None in out_shape:
                body.input[out_idx + 1].type.sequence_type.elem_type.CopyFrom(
                    body_out.type.sequence_type.elem_type
                )
                rerun = True
            continue

        body_in = body.input[out_idx + 1]
        in_shape = get_shape_from_value_info(body_in)
        for dim_idx, (in_dim, out_dim) in enumerate(zip(in_shape, out_shape)):
            if in_dim != out_dim:
                fresh_dim = onnx.TensorShapeProto.Dimension()
                fresh_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, dim_idx))
                body_in.type.tensor_type.shape.dim[dim_idx].CopyFrom(fresh_dim)
                body_out.type.tensor_type.shape.dim[dim_idx].CopyFrom(fresh_dim)
                rerun = True

    if rerun:
        if self.verbose_ > 2:
            logger.debug(
                "Rerun Loop: {}({}...), because of sequence in loop carried variables".format(
                    node.name, node.output[0]
                )
            )
        self._onnx_infer_subgraph(node, body, inc_subgraph_id=False)

    iter_dim = str(self._new_symbolic_dim_from_output(node))
    for idx in range(len(node.output)):
        out_vi = self.known_vi_[node.output[idx]]
        # node output idx corresponds to body output idx + 1
        body_out = body.output[idx + 1]
        out_vi.CopyFrom(body_out)
        if idx >= carried_count:
            assert not is_sequence(out_vi.type)
            body_dims = body_out.type.tensor_type.shape.dim
            out_vi.type.tensor_type.shape.ClearField("dim")
            out_dims = out_vi.type.tensor_type.shape.dim
            out_dims.add().dim_param = iter_dim
            out_dims.extend(list(body_dims))
        out_vi.name = node.output[idx]
| |
|
| | def _infer_MatMul(self, node): |
| | self._compute_matmul_shape(node) |
| |
|
def _infer_MatMulInteger(self, node):
    """Handle MatMulInteger: same as MatMul, with INT32 passed as the second argument."""
    int32_type = onnx.TensorProto.INT32
    self._compute_matmul_shape(node, int32_type)
| |
|
def _infer_NonMaxSuppression(self, node):
    """Set the NonMaxSuppression output to INT64 with shape ``[<new symbolic dim>, 3]``."""
    out_name = node.output[0]
    num_selected = str(self._new_symbolic_dim_from_output(node))
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, onnx.TensorProto.INT64, [num_selected, 3]))
| |
|
def _infer_NonZero(self, node):
    """Set the NonZero output shape to ``[rank of input 0, <new symbolic dim>]``.

    The element type already recorded on the output value info is kept.
    """
    out_name = node.output[0]
    rank_of_input = self._get_shape_rank(node, 0)
    # the second output dim gets its own symbolic name
    count_dim = str(self._new_symbolic_dim_from_output(node, 0, 1))
    out_vi = self.known_vi_[out_name]
    kept_elem_type = out_vi.type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, kept_elem_type, [rank_of_input, count_dim]))
| |
|
def _infer_OneHot(self, node):
    """Set the output shape of a OneHot node.

    A new dim is inserted into input 0's shape at ``axis`` (default -1,
    normalised against rank + 1). That dim is the value of input 1 when it
    is a literal and a new symbolic dim otherwise. The element type is
    copied from input 2.
    """
    indices_sympy_shape = self._get_sympy_shape(node, 0)
    depth = self._try_get_value(node, 1)
    axis = handle_negative_axis(get_attribute(node, "axis", -1), len(indices_sympy_shape) + 1)
    if is_literal(depth):
        depth_dim = depth
    else:
        depth_dim = self._new_symbolic_dim_from_output(node)
    out_shape = get_shape_from_sympy_shape(
        indices_sympy_shape[:axis] + [depth_dim] + indices_sympy_shape[axis:]
    )
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    values_elem_type = self.known_vi_[node.input[2]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, values_elem_type, out_shape))
| |
|
def _infer_Pad(self, node):
    """Set the output shape of a Pad node.

    Pads are read from the ``pads`` attribute for opset <= 10 and from the
    value of input 1 otherwise. Known pads (asserted to hold ``2 * rank``
    entries) are added to each input dim; unknown pads yield a new symbolic
    shape of the same rank. The element type is copied from input 0.
    """
    uses_pads_attribute = get_opset(self.out_mp_) <= 10
    pads = get_attribute(node, "pads") if uses_pads_attribute else self._try_get_value(node, 1)

    in_sympy_shape = self._get_sympy_shape(node, 0)
    rank = len(in_sympy_shape)

    if pads is None:
        out_sympy_shape = self._new_symbolic_shape(rank, node)
    else:
        assert len(pads) == 2 * rank
        pads_before = pads[:rank]
        pads_after = pads[rank:]
        out_sympy_shape = [
            dim + before + after for dim, before, after in zip(in_sympy_shape, pads_before, pads_after)
        ]
        self._update_computed_dims(out_sympy_shape)
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(out_name, in_elem_type, get_shape_from_sympy_shape(out_sympy_shape))
    )
| |
|
def _infer_Pool(self, node):
    """Set the shape of every non-empty output of a pooling node.

    All outputs receive the shape from ``_compute_conv_pool_shape``; each
    keeps the element type already recorded on its value info.
    """
    pooled_sympy_shape = self._compute_conv_pool_shape(node)
    self._update_computed_dims(pooled_sympy_shape)
    for out_name in node.output:
        if out_name:
            out_vi = self.known_vi_[out_name]
            kept_elem_type = out_vi.type.tensor_type.elem_type
            out_vi.CopyFrom(
                helper.make_tensor_value_info(
                    out_name, kept_elem_type, get_shape_from_sympy_shape(pooled_sympy_shape)
                )
            )
| |
|
def _infer_aten_bitwise_or(self, node):
    """Set the output to the broadcast of both input shapes, typed like input 0."""
    lhs_shape = self._get_shape(node, 0)
    rhs_shape = self._get_shape(node, 1)
    broadcast_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
    lhs_vi = self.known_vi_[node.input[0]]
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(out_name, lhs_vi.type.tensor_type.elem_type, broadcast_shape)
    )
| |
|
def _infer_aten_diagonal(self, node):
    """Set the output shape for the ATen ``diagonal`` op.

    ``offset``, ``dim1`` and ``dim2`` (inputs 1-3) must all be known. The
    output keeps every input dim except ``dim1`` and ``dim2`` and appends
    one dim holding the diagonal length:
    ``Max(0, Min(d1, d2 - offset))`` for a non-negative offset and
    ``Max(0, Min(d1 + offset, d2))`` otherwise. Nothing is written when
    output 0 is empty.
    """
    in_sympy_shape = self._get_sympy_shape(node, 0)
    rank = len(in_sympy_shape)
    offset = self._try_get_value(node, 1)
    dim1 = self._try_get_value(node, 2)
    dim2 = self._try_get_value(node, 3)

    assert offset is not None and dim1 is not None and dim2 is not None
    dim1 = handle_negative_axis(dim1, rank)
    dim2 = handle_negative_axis(dim2, rank)

    out_sympy_shape = [size for axis, size in enumerate(in_sympy_shape) if axis not in [dim1, dim2]]

    size1 = in_sympy_shape[dim1]
    size2 = in_sympy_shape[dim2]
    if offset >= 0:
        diag_len = sympy.Max(0, sympy.Min(size1, size2 - offset))
    else:
        diag_len = sympy.Max(0, sympy.Min(size1 + offset, size2))
    out_sympy_shape.append(diag_len)

    out_name = node.output[0]
    if not out_name:
        return
    out_vi = self.known_vi_[out_name]
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(
        helper.make_tensor_value_info(out_name, in_elem_type, get_shape_from_sympy_shape(out_sympy_shape))
    )
| |
|
def _infer_aten_multinomial(self, node):
    """Set the output shape for the ATen ``multinomial`` op.

    Input 0 must have rank 1 or 2. The output (INT64) keeps all but the
    last input dim; the last dim is the value of input 1 when it is truthy
    and a new symbolic dim otherwise.
    """
    probs_sympy_shape = self._get_sympy_shape(node, 0)
    rank = len(probs_sympy_shape)
    assert rank in [1, 2]
    num_samples = self._try_get_value(node, 1)
    last_axis = rank - 1
    if num_samples:
        sample_dim = num_samples
    else:
        sample_dim = str(self._new_symbolic_dim_from_output(node, 0, last_axis))
    out_sympy_shape = probs_sympy_shape[:-1] + [sample_dim]
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(
        helper.make_tensor_value_info(
            out_name, onnx.TensorProto.INT64, get_shape_from_sympy_shape(out_sympy_shape)
        )
    )
| |
|
def _infer_aten_pool2d(self, node):
    """Set output shapes for an ATen 2-D pooling op.

    Input 0 must be rank 4. The first two dims are kept and the last two
    are replaced with new symbolic dims. Every non-empty output gets this
    shape; output 1 is typed INT64, the others like input 0.
    """
    out_sympy_shape = self._get_sympy_shape(node, 0)
    assert len(out_sympy_shape) == 4
    out_sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, axis) for axis in (2, 3)]
    self._update_computed_dims(out_sympy_shape)
    for out_idx, out_name in enumerate(node.output):
        if not out_name:
            continue
        out_vi = self.known_vi_[out_name]
        if out_idx == 1:
            elem_type = onnx.TensorProto.INT64
        else:
            elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi.CopyFrom(
            helper.make_tensor_value_info(out_name, elem_type, get_shape_from_sympy_shape(out_sympy_shape))
        )
| |
|
def _infer_aten_minmax(self, node):
    """Set output shapes for the ATen ``min``/``max`` ops.

    With a single input, output 0 is a scalar typed like input 0.

    With three inputs (data, dim, keepdim), ``keepdim`` must be known. An
    unknown ``dim`` yields a new symbolic shape of the reduced rank; a
    known ``dim`` is removed from the input shape, or replaced by 1 when
    ``keepdim`` is truthy. Output 0 is typed like input 0 and output 1 is
    INT64; both share the reduced shape.
    """
    values_name = node.output[0]
    values_vi = self.known_vi_[values_name]
    if len(node.input) == 1:
        in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        values_vi.CopyFrom(helper.make_tensor_value_info(values_name, in_elem_type, []))
        return

    assert len(node.input) == 3
    keepdim = self._try_get_value(node, 2)
    assert keepdim is not None
    dim = self._try_get_value(node, 1)
    if dim is not None:
        in_sympy_shape = self._get_sympy_shape(node, 0)
        dim = handle_negative_axis(dim, len(in_sympy_shape))
        reduced = in_sympy_shape[:dim] + ([1] if keepdim else []) + in_sympy_shape[dim + 1 :]
    else:
        in_rank = self._get_shape_rank(node, 0)
        reduced = self._new_symbolic_shape(in_rank if keepdim else in_rank - 1, node)

    reduced = get_shape_from_sympy_shape(reduced)
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    values_vi.CopyFrom(helper.make_tensor_value_info(values_name, in_elem_type, reduced))
    indices_name = node.output[1]
    indices_vi = self.known_vi_[indices_name]
    indices_vi.CopyFrom(helper.make_tensor_value_info(indices_name, onnx.TensorProto.INT64, reduced))
| |
|
def _infer_aten_unfold(self, node):
    """Set the output shape for the ATen ``unfold`` op.

    When ``dimension``, ``size`` and ``step`` (inputs 1-3) are all known,
    the dim at ``dimension`` becomes ``(dim - size) // step + 1`` and
    ``size`` is appended as a new last dim. Otherwise a new symbolic shape
    of rank + 1 is used. Nothing is written when output 0 is empty.
    """
    out_sympy_shape = self._get_sympy_shape(node, 0)
    dimension = self._try_get_value(node, 1)
    size = self._try_get_value(node, 2)
    step = self._try_get_value(node, 3)
    all_known = dimension is not None and size is not None and step is not None
    if all_known:
        assert dimension < len(out_sympy_shape)
        out_sympy_shape[dimension] = (out_sympy_shape[dimension] - size) // step + 1
        out_sympy_shape.append(size)
    else:
        in_rank = len(out_sympy_shape)
        out_sympy_shape = self._new_symbolic_shape(in_rank + 1, node)
    self._update_computed_dims(out_sympy_shape)
    out_name = node.output[0]
    if out_name:
        out_vi = self.known_vi_[out_name]
        in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi.CopyFrom(
            helper.make_tensor_value_info(out_name, in_elem_type, get_shape_from_sympy_shape(out_sympy_shape))
        )
| |
|
def _infer_aten_argmax(self, node):
    """Set the output shape for the ATen ``argmax`` op (output typed INT64).

    * Input 1 empty: the output is a scalar.
    * ``keepdim`` (input 2) unknown: nothing is written.
    * ``dim`` known: that dim is set to 1 (``keepdim`` truthy) or removed.
    * ``dim`` unknown: a new symbolic shape of the reduced rank is used.

    Nothing is written when output 0 is empty.
    """
    out_shape = None
    if node.input[1]:
        dim = self._try_get_value(node, 1)
        keepdim = self._try_get_value(node, 2)
        if keepdim is not None:
            work_shape = self._get_sympy_shape(node, 0)
            if dim is None:
                in_rank = len(work_shape)
                work_shape = self._new_symbolic_shape(in_rank if keepdim else in_rank - 1, node)
            else:
                dim = handle_negative_axis(dim, len(work_shape))
                if keepdim:
                    work_shape[dim] = 1
                else:
                    del work_shape[dim]
            self._update_computed_dims(work_shape)
            out_shape = get_shape_from_sympy_shape(work_shape)
    else:
        out_shape = []

    out_name = node.output[0]
    if out_name and out_shape is not None:
        out_vi = self.known_vi_[out_name]
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, onnx.TensorProto.INT64, out_shape))
| |
|
def _infer_aten_group_norm(self, node):
    """Set output shapes for the ATen ``group_norm`` op.

    Output 0 is handled by ``_propagate_shape_and_type``. Outputs 1 and 2,
    when non-empty, are typed like input 0 and given a 2-D shape: the first
    dim of input 0 and the value of input 6. Either dim falls back to a new
    symbolic dim when it is not available.
    """
    self._propagate_shape_and_type(node)
    in_shape = self._get_shape(node, 0)
    has_first_dim = in_shape is not None and len(in_shape) != 0
    first_dim = in_shape[0] if has_first_dim else None
    group = self._try_get_value(node, 6)
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    for out_idx in (1, 2):
        out_name = node.output[out_idx]
        if not out_name:
            continue
        out_vi = self.known_vi_[out_name]
        if first_dim is not None:
            dim0 = first_dim
        else:
            dim0 = str(self._new_symbolic_dim_from_output(node, out_idx, 0))
        if group is not None:
            dim1 = as_scalar(group)
        else:
            dim1 = str(self._new_symbolic_dim_from_output(node, out_idx, 1))
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, in_elem_type, [dim0, dim1]))
| |
|
def _infer_aten_upsample(self, node):
    """Set the output shape for an ATen upsample op.

    The first two input dims are kept. The remaining dims come from the
    value of input 1 when it is known (``np.int64`` entries are converted
    with ``.item()``), and are new symbolic dims otherwise. Nothing is
    written when the input shape is unknown or output 0 is empty.
    """
    out_shape = None
    in_shape = self._get_shape(node, 0)
    if in_shape is not None:
        out_shape = in_shape[:2]
        output_size = self._try_get_value(node, 1)
        if output_size is None:
            in_rank = len(in_shape)
            out_shape += [str(self._new_symbolic_dim_from_output(node, 0, axis)) for axis in range(2, in_rank)]
        else:
            out_shape += [entry.item() if type(entry) is np.int64 else entry for entry in output_size]

    out_name = node.output[0]
    if out_name and out_shape is not None:
        in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi = self.known_vi_[out_name]
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, in_elem_type, out_shape))
| |
|
| | def _infer_BatchNormalization(self, node): |
| | self._propagate_shape_and_type(node) |
| |
|
| | |
| | for i in [1, 2, 3, 4]: |
| | if i < len(node.output) and node.output[i]: |
| | |
| | self._propagate_shape_and_type(node, input_index=1, output_index=i) |
| |
|
def _infer_Range(self, node):
    """Set the 1-D output shape of a Range node.

    When start, limit and delta are all known the length is
    ``Max(ceiling((limit - start) / delta), 0)``; otherwise it is a new
    symbolic dim. The element type is copied from input 0.
    """
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    known_inputs = self._get_int_or_float_values(node)
    if all(value is not None for value in known_inputs):
        start = as_scalar(known_inputs[0])
        limit = as_scalar(known_inputs[1])
        delta = as_scalar(known_inputs[2])
        length = sympy.Max(sympy.ceiling((limit - start) / delta), 0)
    else:
        length = self._new_symbolic_dim_from_output(node)
    out_sympy_shape = [length]
    self._update_computed_dims(out_sympy_shape)
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(
        helper.make_tensor_value_info(out_name, in_elem_type, get_shape_from_sympy_shape(out_sympy_shape))
    )
| |
|
def _infer_ReduceSum(self, node):
    """Set the output shape of ReduceSum when axes are given as an input.

    Only handles opset >= 13 nodes that have more than one input; in every
    other case nothing is written here.

    * Axes value unknown: ``keepdims`` must be truthy and a new symbolic
      shape with the input rank is used.
    * Axes value known: reduced dims become 1 (``keepdims`` truthy) or are
      dropped. An empty axes list reduces every dim, unless the
      ``noop_with_empty_axes`` attribute is set, in which case the shape is
      left unchanged. Previously an empty axes list always left the shape
      unchanged.

    The element type is copied from input 0.
    """
    keep_dims = get_attribute(node, "keepdims", 1)
    if get_opset(self.out_mp_) < 13 or len(node.input) <= 1:
        return

    axes = self._try_get_value(node, 1)
    vi = self.known_vi_[node.output[0]]
    if axes is None:
        # the reduced rank cannot be derived without knowing the axes
        assert keep_dims
        output_shape = get_shape_from_sympy_shape(
            self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
        )
    else:
        shape = self._get_shape(node, 0)
        rank = len(shape)
        axes = [handle_negative_axis(a, rank) for a in axes]
        if not axes and not get_attribute(node, "noop_with_empty_axes", 0):
            axes = list(range(rank))
        output_shape = []
        for i, d in enumerate(shape):
            if i not in axes:
                output_shape.append(d)
            elif keep_dims:
                output_shape.append(1)

    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            output_shape,
        )
    )
| |
|
def _infer_ReduceProd(self, node):
    """Propagate the value of a ReduceProd over axis 0 without kept dims.

    Only acts when ``keepdims`` is 0 and the ``axes`` attribute equals
    ``[0]``. If the value of input 0 is known, its product (via
    ``sympy_reduce_product``) is stored in ``sympy_data_``.
    """
    axes_attr = get_attribute(node, "axes")
    keep_dims = get_attribute(node, "keepdims", 1)
    if keep_dims != 0 or axes_attr != [0]:
        return
    known_data = self._get_int_or_float_values(node)[0]
    if known_data is None:
        return
    self.sympy_data_[node.output[0]] = sympy_reduce_product(known_data)
| |
|
def _infer_RelativePositionBias(self, node):
    """Set the output shape of a RelativePositionBias node.

    Requires the values of inputs 1 and 2; if either is unknown nothing is
    written. The output shape is
    ``[1, <dim 1 of input 0>, str(input 1), str(input 2)]`` and the element
    type is copied from input 0.
    """
    seq_len = self._try_get_value(node, 1)
    real_seq_len = self._try_get_value(node, 2)
    if seq_len is None or real_seq_len is None:
        return

    second_dim = self._get_sympy_shape(node, 0)[1]
    bias_shape = [1, second_dim, str(seq_len), str(real_seq_len)]

    out_name = node.output[0]
    in_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi = self.known_vi_[out_name]
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, in_elem_type, bias_shape))
| |
|
def _infer_Reshape(self, node):
    """Set the output shape of a Reshape node and forward sympy data.

    * Target shape value unknown: input 1 must be 1-D with a literal
      length, which becomes the rank of a new symbolic output shape.
    * Target shape value known: entries equal to 0 take the input dim at
      the same position, a single -1 entry is resolved as
      ``total input size // product of the other output dims``, and all
      other entries are used as given.

    The element type already recorded on the output value info is kept.
    ``_pass_on_sympy_data`` is called in both cases.
    """
    target = self._try_get_value(node, 1)
    out_name = node.output[0]
    out_vi = self.known_vi_[out_name]
    if target is None:
        shape_of_target = self._get_shape(node, 1)
        assert len(shape_of_target) == 1
        out_rank = shape_of_target[0]
        assert is_literal(out_rank)
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                out_name,
                out_vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(self._new_symbolic_shape(out_rank, node)),
            )
        )
    else:
        in_sympy_shape = self._get_sympy_shape(node, 0)
        total_size = 1
        for in_dim in in_sympy_shape:
            total_size = total_size * in_dim

        out_sympy_shape = []
        inferred_pos = -1
        known_size = 1
        for pos, entry in enumerate(target):
            # first decide which dim goes into the output shape ...
            if type(entry) is sympy.Symbol:
                out_sympy_shape.append(entry)
            elif entry == 0:
                out_sympy_shape.append(in_sympy_shape[pos])
                known_size = known_size * in_sympy_shape[pos]
            else:
                out_sympy_shape.append(entry)
            # ... then track the -1 position and the product of the rest
            if entry == -1:
                inferred_pos = pos
            elif entry != 0:
                known_size = known_size * entry

        assert out_sympy_shape.count(-1) < 2
        if -1 in out_sympy_shape:
            out_sympy_shape[inferred_pos] = total_size // known_size

        self._update_computed_dims(out_sympy_shape)
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                out_name,
                out_vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(out_sympy_shape),
            )
        )

    self._pass_on_sympy_data(node)
| |
|
def _infer_Resize(self, node):
    """Set the output shape of a Resize node.

    Opset <= 10: scales come from input 1; each output dim is
    ``floor(dim * scale)``. Nothing is written if the scales are unknown.

    Later opsets: inputs 1-3 are roi, scales and sizes.
      * sizes known: output dims are ``floor(size)``;
      * scales known: output dims are
        ``floor(dim * (roi_end - roi_start) * scale)``, where the roi is
        ``[0, 1]`` per axis unless ``coordinate_transformation_mode`` is
        ``tf_crop_and_resize``, in which case the roi input is used;
      * otherwise (including ``tf_crop_and_resize`` with an unknown roi):
        a new symbolic shape of the input rank.

    String attributes are returned as ``bytes``, so the mode is decoded
    before comparison. The previous comparison of the raw attribute against
    a ``str`` was always False, which meant the roi was never applied.

    The element type is copied from input 0.
    """
    vi = self.known_vi_[node.output[0]]
    input_sympy_shape = self._get_sympy_shape(node, 0)
    if get_opset(self.out_mp_) <= 10:
        scales = self._try_get_value(node, 1)
        if scales is None:
            return
        new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)]
        self._update_computed_dims(new_sympy_shape)
    else:
        roi = self._try_get_value(node, 1)
        scales = self._try_get_value(node, 2)
        sizes = self._try_get_value(node, 3)
        new_sympy_shape = None
        if sizes is not None:
            new_sympy_shape = [sympy.simplify(sympy.floor(s)) for s in sizes]
        elif scales is not None:
            rank = len(scales)
            mode = get_attribute(node, "coordinate_transformation_mode")
            if isinstance(mode, bytes):
                mode = mode.decode("utf-8")
            roi_start = roi_end = None
            if mode != "tf_crop_and_resize":
                roi_start = [0] * rank
                roi_end = [1] * rank
            elif roi is not None:
                assert len(roi) == 2 * rank
                roi_start = list(roi)[:rank]
                roi_end = list(roi)[rank:]
            if roi_start is not None:
                new_sympy_shape = [
                    sympy.simplify(sympy.floor(d * (end - start) * scale))
                    for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, list(scales))
                ]

        if new_sympy_shape is None:
            new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
        else:
            self._update_computed_dims(new_sympy_shape)

    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            self.known_vi_[node.input[0]].type.tensor_type.elem_type,
            get_shape_from_sympy_shape(new_sympy_shape),
        )
    )
| |
|
def _infer_Scan(self, node):
    """Infer the outputs of a Scan node by running inference on its body.

    The first ``len(node.input)`` body inputs receive the value info of the
    matching node inputs (keeping their own names); for scan inputs the
    scanned axis is removed first. After inferring the body, state outputs
    are copied as they are, while scan outputs get the scan dim (read from
    the last node input at its scan axis) inserted at the position given by
    ``scan_output_axes``.
    """
    body = get_attribute(node, "body")
    num_scan_inputs = get_attribute(node, "num_scan_inputs")
    input_axes_attr = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs)
    num_states = len(node.input) - num_scan_inputs
    input_axes = [
        handle_negative_axis(ax, self._get_shape_rank(node, k + num_states))
        for k, ax in enumerate(input_axes_attr)
    ]

    # The body may declare more inputs than the node supplies; only the
    # leading len(node.input) of them are seeded here.
    assert len(body.input) >= len(node.input)
    seeded_inputs = body.input[: len(node.input)]
    for idx, body_in in enumerate(seeded_inputs):
        kept_name = body_in.name
        body_in.CopyFrom(self.known_vi_[node.input[idx]])
        if idx >= num_states:
            dims = body_in.type.tensor_type.shape.dim
            dims.remove(dims[input_axes[idx - num_states]])
        body_in.name = kept_name
    self._onnx_infer_subgraph(node, body)

    num_scan_outputs = len(node.output) - num_states
    output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs)
    scan_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[input_axes[-1]]
    for idx, out_name in enumerate(node.output):
        out_vi = self.known_vi_[out_name]
        body_out = body.output[idx]
        if idx < num_states:
            out_vi.CopyFrom(body_out)
        else:
            step_shape = get_shape_from_type_proto(body_out.type)
            insert_at = handle_negative_axis(output_axes[idx - num_states], len(step_shape) + 1)
            full_shape = step_shape[:insert_at] + [scan_dim] + step_shape[insert_at:]
            out_vi.CopyFrom(
                helper.make_tensor_value_info(out_name, body_out.type.tensor_type.elem_type, full_shape)
            )
        out_vi.name = out_name
| |
|
def _infer_ScatterElements(self, node):
    """Set the ScatterElements output to the shape and element type of input 0."""
    out_name = node.output[0]
    shape_of_data = self._get_shape(node, 0)
    out_vi = self.known_vi_[out_name]
    data_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    out_vi.CopyFrom(helper.make_tensor_value_info(out_name, data_elem_type, shape_of_data))
| |
|
def _infer_SequenceAt(self, node):
    """Give names to unknown dims of a SequenceAt output.

    For every dim of input 0's shape that is ``None``, the dim at the same
    position of the output value info is replaced with a new symbolic dim.
    Nothing happens when the input shape itself is ``None``.
    """
    elem_shape = self._get_shape(node, 0)
    out_vi = self.known_vi_[node.output[0]]
    if elem_shape is None:
        return
    unknown_positions = (pos for pos, dim in enumerate(elem_shape) if dim is None)
    for pos in unknown_positions:
        named_dim = onnx.TensorShapeProto.Dimension()
        named_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, pos))
        out_vi.type.tensor_type.shape.dim[pos].CopyFrom(named_dim)
| |
|
def _infer_SequenceInsert(self, node):
    """Set the output of a SequenceInsert node.

    The output value info starts as a copy of the input sequence (input 0),
    renamed to the output, and its type is then fused with the type of the
    inserted tensor (input 1) through ``_fuse_tensor_type``.
    """
    out_name = node.output[0]
    source_seq_vi = self.known_vi_[node.input[0]]
    inserted_vi = self.known_vi_[node.input[1]]
    result_vi = self.known_vi_[out_name]
    result_vi.CopyFrom(source_seq_vi)
    result_vi.name = out_name
    self._fuse_tensor_type(node, 0, result_vi.type, inserted_vi.type)
| |
|
| | def _infer_Shape(self, node): |
| | self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) |
| |
|
def _infer_Size(self, node):
    """Record the element count of input 0 and mark the output as an INT64 scalar.

    The count is ``sympy_reduce_product`` over the sympy shape of input 0.
    """
    out_name = node.output[0]
    in_sympy_shape = self._get_sympy_shape(node, 0)
    self.sympy_data_[out_name] = sympy_reduce_product(in_sympy_shape)
    scalar_int64_vi = helper.make_tensor_value_info(out_name, onnx.TensorProto.INT64, [])
    self.known_vi_[out_name].CopyFrom(scalar_int64_vi)
| |
|
def _infer_Slice(self, node):
    """Infer the output shape (and, when possible, symbolic data) of a Slice node.

    For opset <= 9 the ``axes``/``starts``/``ends`` come from node attributes
    and the step is always 1; for later opsets they are read from inputs
    1..4 via ``_try_get_value``.  When starts or ends are unknown, the
    affected dimensions are replaced by fresh symbolic dimensions;
    otherwise each sliced dimension is computed symbolically from
    ``(end - start) / step``.

    If input 0 has symbolic data and the slice is a single-range slice on
    axis 0, the symbolic data is sliced as well so that shape computations
    can be followed through the node.
    """

    def flatten_min(expr):
        """Split ``min(a, b) + rest`` into ``[a + rest, b + rest]``.

        Lets an inequality be checked separately for each argument of the
        single ``sympy.Min`` term.  When ``expr`` does not contain exactly
        one ``sympy.Min`` term, ``[expr]`` is returned unchanged.
        """
        assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}"
        min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)]
        if len(min_positions) == 1:
            min_pos = min_positions[0]

            def replace_min_with_arg(arg_idx):
                replaced = list(expr.args)
                assert isinstance(
                    replaced[min_pos], sympy.Min
                ), f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}"
                assert (
                    len(replaced[min_pos].args) == 2
                ), f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}"
                replaced[min_pos] = replaced[min_pos].args[arg_idx]
                return sympy.Add(*replaced)

            return [
                replace_min_with_arg(0),
                replace_min_with_arg(1),
            ]
        return [expr]

    def less_equal(x, y):
        """Return whether ``x <= y``, trying several equivalent formulations.

        Each formulation may raise ``TypeError`` when the relation cannot be
        decided; the next one is then attempted.  The last attempt may
        still raise ``TypeError``.
        """
        try:
            return bool(x <= y)
        except TypeError:
            pass
        try:
            return bool(y >= x)
        except TypeError:
            pass
        try:
            return bool(-x >= -y)
        except TypeError:
            pass
        try:
            return bool(-y <= -x)
        except TypeError:
            pass
        try:
            return bool(y - x >= 0)
        except TypeError:
            # Last attempt; this may raise TypeError to the caller.
            return all(bool(d >= 0) for d in flatten_min(y - x))

    def handle_negative_index(index, bound):
        """normalizes a negative index to be in [0, bound)"""
        try:
            if not less_equal(0, index):
                if is_literal(index) and index <= -self.int_max_:
                    # Sentinel "minus infinity" values are handled by the caller.
                    return index
                return bound + index
        except TypeError:
            logger.warning(f"Cannot determine if {index} < 0")
        return index

    if get_opset(self.out_mp_) <= 9:
        axes = get_attribute(node, "axes")
        starts = get_attribute(node, "starts")
        ends = get_attribute(node, "ends")
        if not axes:
            axes = list(range(len(starts)))
        steps = [1] * len(axes)
    else:
        starts = as_list(self._try_get_value(node, 1), keep_none=True)
        ends = as_list(self._try_get_value(node, 2), keep_none=True)
        axes = self._try_get_value(node, 3)
        steps = self._try_get_value(node, 4)
        if axes is None and not (starts is None and ends is None):
            axes = list(range(0, len(starts if starts is not None else ends)))
        if steps is None and not (starts is None and ends is None):
            steps = [1] * len(starts if starts is not None else ends)
        axes = as_list(axes, keep_none=True)
        steps = as_list(steps, keep_none=True)

    new_sympy_shape = self._get_sympy_shape(node, 0)
    if starts is None or ends is None:
        # Slice bounds are not known: the sliced dims become new symbols.
        if axes is None:
            for i in range(len(new_sympy_shape)):
                new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
        else:
            new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
            for i in axes:
                new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
    else:
        for i, s, e, t in zip(axes, starts, ends, steps):
            e = handle_negative_index(e, new_sympy_shape[i])
            if is_literal(e):
                if e >= self.int_max_:
                    e = new_sympy_shape[i]
                elif e <= -self.int_max_:
                    e = 0 if s > 0 else -1
                elif is_literal(new_sympy_shape[i]):
                    if e < 0:
                        e = max(0, e + new_sympy_shape[i])
                    e = min(e, new_sympy_shape[i])
                else:
                    if e > 0:
                        # An end of 1 is kept as-is instead of being wrapped in Min.
                        e = sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
            else:
                if is_literal(new_sympy_shape[i]):
                    e = sympy.Min(e, new_sympy_shape[i])
                else:
                    try:
                        if not less_equal(e, new_sympy_shape[i]):
                            e = new_sympy_shape[i]
                    except Exception:
                        # Deliberate best effort: fall back to the full dimension.
                        logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal")
                        e = new_sympy_shape[i]

            s = handle_negative_index(s, new_sympy_shape[i])
            if is_literal(new_sympy_shape[i]) and is_literal(s):
                s = max(0, min(s, new_sympy_shape[i]))

            new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)

        self._update_computed_dims(new_sympy_shape)

    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(
        helper.make_tensor_value_info(
            node.output[0],
            vi.type.tensor_type.elem_type,
            get_shape_from_sympy_shape(new_sympy_shape),
        )
    )

    # Propagate symbolic data for a single-range slice along axis 0.
    if (
        node.input[0] in self.sympy_data_
        and [0] == axes
        and starts is not None
        and len(starts) == 1
        and ends is not None
        and len(ends) == 1
        and steps is not None
        and len(steps) == 1
    ):
        input_sympy_data = self.sympy_data_[node.input[0]]
        # np.array is a factory function, not a type; the array type is
        # np.ndarray.  Comparing against np.array never matched, so 1-D
        # arrays were silently skipped.
        if isinstance(input_sympy_data, list) or (
            isinstance(input_sympy_data, np.ndarray) and input_sympy_data.ndim == 1
        ):
            self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]
| |
|
def _infer_SoftmaxCrossEntropyLoss(self, node):
    """Infer output types/shapes of a SoftmaxCrossEntropyLoss node.

    Output 0 gets an empty shape proto and the element type of input 0,
    unless the node carries an ``output_type`` attribute, which then takes
    precedence.  When a second output exists it gets the same element type
    and the shape of input 0.
    """
    first_vi = self.known_vi_[node.output[0]]
    input_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    # An explicit "output_type" attribute overrides the input's dtype.
    override = get_attribute(node, "output_type", None)
    elem_type = input_elem_type if override is None else override

    first_vi.type.tensor_type.elem_type = elem_type
    first_vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())

    if len(node.output) <= 1:
        return

    data_shape = self._get_shape(node, 0)
    second_vi = self.known_vi_[node.output[1]]
    second_vi.CopyFrom(helper.make_tensor_value_info(second_vi.name, elem_type, data_shape))
| |
|
def _infer_Split_Common(self, node, make_value_info_func):
    """Shared shape inference for Split-like nodes.

    ``make_value_info_func`` builds the value info for each output from
    ``(name, elem_type, shape)``.  The split sizes come from the ``split``
    attribute for opset < 13 and from input 1 otherwise; when no sizes are
    given, the split axis is divided evenly across the node's outputs.
    """
    dims = self._get_sympy_shape(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis", 0), len(dims))

    # "split" moved from an attribute to an input in opset 13; only one of
    # the two forms may be present.
    if get_opset(self.out_mp_) < 13:
        split = get_attribute(node, "split")
        assert self._try_get_value(node, 1) is None
    else:
        split = self._try_get_value(node, 1)
        assert get_attribute(node, "split") is None

    if split is not None:
        split = [sympy.Integer(size) for size in split]
    else:
        output_count = len(node.output)
        split = [dims[axis] / sympy.Integer(output_count)] * output_count
        self._update_computed_dims(split)

    before_axis = dims[:axis]
    after_axis = dims[axis + 1 :]
    for out_idx, size in enumerate(split):
        out_name = node.output[out_idx]
        vi = self.known_vi_[out_name]
        vi.CopyFrom(
            make_value_info_func(
                out_name,
                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(before_axis + [size] + after_axis),
            )
        )
        self.known_vi_[vi.name] = vi
| |
|
def _infer_Split(self, node):
    """Infer Split outputs as plain tensors via the shared split logic."""
    tensor_vi_factory = helper.make_tensor_value_info
    self._infer_Split_Common(node, tensor_vi_factory)
| |
|
def _infer_SplitToSequence(self, node):
    """Infer SplitToSequence outputs as sequences via the shared split logic."""
    sequence_vi_factory = helper.make_sequence_value_info
    self._infer_Split_Common(node, sequence_vi_factory)
| |
|
def _infer_Squeeze(self, node):
    """Infer the output shape of a Squeeze node.

    Axes come from the ``axes`` attribute for opset < 13 and from input 1
    otherwise.  Without axes, every dimension not equal to 1 is kept, so
    symbolic dimensions are assumed never to be 1.  With axes, the listed
    dimensions are dropped; a symbolic dimension at a listed axis is
    assumed to be 1.
    """
    input_shape = self._get_shape(node, 0)

    # "axes" moved from an attribute to an input in opset 13; only one of
    # the two forms may be present.
    if get_opset(self.out_mp_) < 13:
        axes = get_attribute(node, "axes")
        assert self._try_get_value(node, 1) is None
    else:
        axes = self._try_get_value(node, 1)
        assert get_attribute(node, "axes") is None

    if axes is not None:
        axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
        output_shape = []
        for i, dim in enumerate(input_shape):
            if i not in axes:
                output_shape.append(dim)
                continue
            dim_is_symbolic = type(dim) != int
            assert dim == 1 or dim_is_symbolic
            if self.verbose_ > 0 and dim_is_symbolic:
                logger.debug(
                    f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
                    f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
                )
    else:
        output_shape = [dim for dim in input_shape if dim != 1]
        if self.verbose_ > 0:
            symbolic_dimensions = [dim for dim in input_shape if type(dim) != int]
            if symbolic_dimensions:
                logger.debug(
                    f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
                    f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
                )

    vi = self.known_vi_[node.output[0]]
    input_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], input_elem_type, output_shape))
    self._pass_on_sympy_data(node)
| |
|
def _infer_Tile(self, node):
    """Infer the output shape of a Tile node.

    When the repeats (input 1) are known, each input dimension is
    multiplied by its repeat count; otherwise a fully symbolic shape of
    the input's rank is generated.
    """
    repeats_value = self._try_get_value(node, 1)
    if repeats_value is None:
        tiled_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
    else:
        tiled_shape = [
            dim * repeats_value[axis] for axis, dim in enumerate(self._get_sympy_shape(node, 0))
        ]
        self._update_computed_dims(tiled_shape)

    vi = self.known_vi_[node.output[0]]
    out_shape = get_shape_from_sympy_shape(tiled_shape)
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, out_shape))
| |
|
def _infer_TopK(self, node):
    """Infer the output shapes of a TopK node.

    Every output gets the input shape with the ``axis`` dimension replaced
    by ``k``.  ``k`` is an attribute for opset <= 9 and an input
    afterwards; when unknown, a new symbolic dimension is used.
    """
    rank = self._get_shape_rank(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
    new_shape = self._get_shape(node, 0)

    if get_opset(self.out_mp_) > 9:
        k = self._get_int_or_float_values(node)[1]
    else:
        k = get_attribute(node, "k")

    k = self._new_symbolic_dim_from_output(node) if k is None else as_scalar(k)

    if type(k) not in [int, str]:
        # k is a symbolic expression: go through the sympy shape so the
        # computed dimension gets registered.
        sympy_shape = self._get_sympy_shape(node, 0)
        sympy_shape[axis] = k
        self._update_computed_dims(sympy_shape)
        new_shape = get_shape_from_sympy_shape(sympy_shape)
    else:
        new_shape[axis] = k

    for out_name in node.output:
        vi = self.known_vi_[out_name]
        vi.CopyFrom(helper.make_tensor_value_info(out_name, vi.type.tensor_type.elem_type, new_shape))
| |
|
def _infer_Transpose(self, node):
    """Propagate symbolic data through a Transpose node.

    Only acts when input 0 has symbolic data: the data is reshaped to the
    input shape, transposed with ``perm`` (reversed axes when the
    attribute is absent) and stored flattened as a Python list.
    """
    if node.input[0] not in self.sympy_data_:
        return

    data_shape = self._get_shape(node, 0)
    reversed_axes = reversed(list(range(len(data_shape))))
    perm = get_attribute(node, "perm", reversed_axes)

    source = np.array(self.sympy_data_[node.input[0]]).reshape(*data_shape)
    transposed = np.transpose(source, axes=tuple(perm))
    self.sympy_data_[node.output[0]] = transposed.flatten().tolist()
| |
|
def _infer_Unsqueeze(self, node):
    """Infer the output shape of an Unsqueeze node.

    Axes come from the ``axes`` attribute for opset < 13 and from input 1
    otherwise.  The output rank is the input rank plus the number of axes;
    a dimension of 1 is placed at each listed axis and the input
    dimensions fill the remaining positions in order.
    """
    input_shape = self._get_shape(node, 0)

    # "axes" moved from an attribute to an input in opset 13; only one of
    # the two forms may be present.
    if get_opset(self.out_mp_) < 13:
        axes = get_attribute(node, "axes")
        assert self._try_get_value(node, 1) is None
    else:
        axes = self._try_get_value(node, 1)
        assert get_attribute(node, "axes") is None

    output_rank = len(input_shape) + len(axes)
    axes = [handle_negative_axis(a, output_rank) for a in axes]

    output_shape = []
    next_input_axis = 0
    for out_axis in range(output_rank):
        if out_axis not in axes:
            output_shape.append(input_shape[next_input_axis])
            next_input_axis += 1
            continue
        output_shape.append(1)

    vi = self.known_vi_[node.output[0]]
    input_elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], input_elem_type, output_shape))

    self._pass_on_sympy_data(node)
| |
|
def _infer_ZipMap(self, node):
    """Infer the output type of a ZipMap node.

    The output is declared as a sequence of maps with FLOAT tensor values.
    The key type is INT64 when ``classlabels_int64s`` is set, otherwise
    STRING when ``classlabels_strings`` is set; one of them must be
    present.
    """
    key_type_by_attr = (
        ("classlabels_int64s", onnx.TensorProto.INT64),
        ("classlabels_strings", onnx.TensorProto.STRING),
    )
    # First attribute that is present decides the key type.
    map_key_type = next(
        (key_type for attr_name, key_type in key_type_by_attr if get_attribute(node, attr_name) is not None),
        None,
    )
    assert map_key_type is not None

    replacement = onnx.ValueInfoProto()
    replacement.name = node.output[0]
    map_type = replacement.type.sequence_type.elem_type.map_type
    map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
    map_type.key_type = map_key_type

    self.known_vi_[node.output[0]].CopyFrom(replacement)
| |
|
def _infer_Attention(self, node):
    """Infer output shapes of an Attention node.

    Only acts when input 0 has rank 3.  The last output dimension is
    taken from ``qkv_hidden_sizes[2]`` when that attribute exists,
    otherwise from one third of the bias length (or of the weights' second
    dimension) when that value is an int.  When a second output exists,
    its shape is derived from the rank-5 past input (input 4) if present,
    else built from the input shape and ``num_heads``.
    """
    shape = self._get_shape(node, 0)
    shape_weights = self._get_shape(node, 1)
    shape_bias = self._try_get_shape(node, 2)
    if shape_bias is None:
        tripled_hidden_size = shape_weights[1]
    else:
        assert len(shape_bias) == 1
        tripled_hidden_size = shape_bias[0]

    if not (shape and len(shape) == 3):
        return

    qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
    if qkv_hidden_sizes_attr is not None:
        assert len(qkv_hidden_sizes_attr) == 3
        shape[2] = int(qkv_hidden_sizes_attr[2])
    elif isinstance(tripled_hidden_size, int):
        shape[2] = int(tripled_hidden_size / 3)

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))

    if len(node.output) <= 1:
        return

    input_shape = self._get_shape(node, 0)
    past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
    mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []

    if past_shape and len(past_shape) == 5:
        # Dimension 3 of the past shape is replaced: by the mask's last
        # dimension when a rank-2/3 mask exists, else by past + input[1].
        if mask_shape and len(mask_shape) in [2, 3]:
            past_shape[3] = mask_shape[-1]
        elif input_shape and len(input_shape) == 3:
            if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
                past_shape[3] = input_shape[1] + past_shape[3]
            else:
                past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
        present_shape = past_shape
    else:
        # No usable past input, but the second output still exists.
        num_heads = get_attribute(node, "num_heads")
        head_size = input_shape[2] // num_heads
        present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size]

    vi = self.known_vi_[node.output[1]]
    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
| |
|
def _infer_GatedRelativePositionBias(self, node):
    """Infer the output shape of a GatedRelativePositionBias node.

    The output is ``[d0, num_heads, d1, d1]`` where ``d0``/``d1`` are the
    first two dimensions of input 6 when its shape is available, otherwise
    of the rank-3 input 0.  The element type is that of input 0.
    """
    num_heads = get_attribute(node, "num_heads")

    reference_shape = self._try_get_shape(node, 6)
    if reference_shape is None:
        reference_shape = self._get_shape(node, 0)
        assert reference_shape is not None and len(reference_shape) == 3

    first_dim, second_dim = reference_shape[0], reference_shape[1]
    output_shape = [first_dim, num_heads, second_dim, second_dim]

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
| |
|
def _infer_PackedAttention(self, node):
    """Infer the output shape of a PackedAttention node.

    Only acts when input 0 has rank 2.  The last output dimension is taken
    from ``qkv_hidden_sizes[2]`` when that attribute exists, otherwise from
    one third of the bias length (or of the weights' second dimension)
    when that value is an int.
    """
    shape = self._get_shape(node, 0)
    shape_weights = self._get_shape(node, 1)
    shape_bias = self._try_get_shape(node, 2)
    if shape_bias is None:
        tripled_hidden_size = shape_weights[1]
    else:
        assert len(shape_bias) == 1
        tripled_hidden_size = shape_bias[0]

    if not (shape and len(shape) == 2):
        return

    qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
    if qkv_hidden_sizes_attr is not None:
        assert len(qkv_hidden_sizes_attr) == 3
        shape[1] = int(qkv_hidden_sizes_attr[2])
    elif isinstance(tripled_hidden_size, int):
        shape[1] = int(tripled_hidden_size / 3)

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
| |
|
def _infer_PackedMultiHeadAttention(self, node):
    """Infer the output shape of a PackedMultiHeadAttention node.

    Uses the shape of input 2 when it is available with rank 2; otherwise
    input 0 must have rank 4 and the output is
    ``[q[0], q[1] * q[3]]``.  The element type is that of input 0.
    """
    output_shape = self._try_get_shape(node, 2)
    if output_shape is None or len(output_shape) != 2:
        shape_query = self._get_shape(node, 0)
        assert shape_query is not None and len(shape_query) == 4
        output_shape = [shape_query[0], shape_query[1] * shape_query[3]]

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
| |
|
def _infer_RemovePadding(self, node):
    """Infer the four outputs of a RemovePadding node.

    Only acts when input 0 has rank 3.  Output 0 is
    ``["token_count", shape[2]]`` with the input's element type; outputs
    1-3 are INT32 tensors shaped ``[shape[0], shape[1]]``,
    ``["batch_size + 1"]`` and ``[1]`` respectively.
    """
    shape = self._get_shape(node, 0)
    if not (shape and len(shape) == 3):
        return

    int32 = onnx.TensorProto.INT32
    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    # (output index, element type, shape), written in output order.
    output_specs = (
        (0, output_dtype, ["token_count", shape[2]]),
        (1, int32, [shape[0], shape[1]]),
        (2, int32, ["batch_size + 1"]),
        (3, int32, [1]),
    )
    for out_idx, elem_type, out_shape in output_specs:
        out_vi = self.known_vi_[node.output[out_idx]]
        out_vi.CopyFrom(helper.make_tensor_value_info(node.output[out_idx], elem_type, out_shape))
| |
|
def _infer_RestorePadding(self, node):
    """Infer the output shape of a RestorePadding node.

    Only acts when inputs 0 and 1 both have rank 2; the output is then
    ``[offset[0], offset[1], input[1]]`` with the element type of input 0.
    """
    shape_input = self._get_shape(node, 0)
    shape_token_offset = self._get_shape(node, 1)

    input_ok = shape_input and len(shape_input) == 2
    offset_ok = shape_token_offset and len(shape_token_offset) == 2
    if not (input_ok and offset_ok):
        return

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    output_shape = [*shape_token_offset[:2], shape_input[1]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
| |
|
def _infer_BiasGelu(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_MultiHeadAttention(self, node):
    """Infer output shapes of a MultiHeadAttention node.

    Handles a rank-3 or rank-5 query (input 0).  For rank 3 the output is
    the query shape, with its last dimension replaced by the value's last
    dimension when rank-3 key and value shapes are available.  For rank 5
    the output is ``[q[0], q[1], q[2] * q[4]]``.  When outputs 1 and 2 are
    both present they get
    ``[batch, num_heads, total_sequence_length, head_size]``, where the
    past length (input 6, dimension 2) is added when available.
    """
    query_shape = self._get_shape(node, 0)
    if query_shape is None:
        return

    total_sequence_length = None
    output_dtype = None
    query_rank = len(query_shape)

    if query_rank == 3:
        key_shape = self._try_get_shape(node, 1)
        # NOTE: output_shape aliases query_shape, so the assignment below
        # also changes query_shape[2] as seen by the head_size computation.
        output_shape = query_shape
        if key_shape is not None and len(key_shape) == 3:
            value_shape = self._try_get_shape(node, 2)
            if value_shape is not None and len(value_shape) == 3:
                output_shape[2] = value_shape[2]
            total_sequence_length = key_shape[1]

        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
    elif query_rank == 5:
        both_int = isinstance(query_shape[2], int) and isinstance(query_shape[4], int)
        last_dim = query_shape[2] * query_shape[4] if both_int else f"{query_shape[2]}*{query_shape[4]}"
        output_shape = [query_shape[0], query_shape[1], last_dim]

        total_sequence_length = query_shape[1]

        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
        vi = self.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))

    if len(node.output) <= 1:
        return

    batch_size = query_shape[0]
    num_heads = get_attribute(node, "num_heads")

    if query_rank != 3:
        head_size = query_shape[4]
    elif isinstance(query_shape[2], int):
        head_size = int(query_shape[2] / num_heads)
    else:
        head_size = f"{query_shape[2]}/{num_heads}"

    past_shape = self._try_get_shape(node, 6)
    if past_shape is not None:
        if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
            total_sequence_length = past_shape[2] + total_sequence_length
        else:
            total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"

    present_shape = [batch_size, num_heads, total_sequence_length, head_size]

    assert output_dtype is not None
    if len(node.output) > 2 and node.output[1] and node.output[2]:
        for present_idx in (1, 2):
            vi = self.known_vi_[node.output[present_idx]]
            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
| |
|
def _infer_DecoderMaskedMultiHeadAttention(self, node):
    """Infer output shapes of a DecoderMaskedMultiHeadAttention node.

    Output 0 gets the shape and element type of input 0.  When outputs 1
    and 2 are both present and the shape of input 5 is available, both
    get that shape.
    """
    query_shape = self._get_shape(node, 0)
    if query_shape is None:
        return

    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    assert output_dtype is not None
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))

    has_present_outputs = len(node.output) > 2 and node.output[1] and node.output[2]
    if not has_present_outputs:
        return

    past_shape = self._try_get_shape(node, 5)
    if past_shape is None:
        return
    for present_idx in (1, 2):
        vi = self.known_vi_[node.output[present_idx]]
        vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
| |
|
def _infer_FastGelu(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_Gelu(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_QuickGelu(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_GemmFastGelu(self, node):
    """Infer the output shape using the shared matmul shape computation."""
    compute_shape = self._compute_matmul_shape
    compute_shape(node)
| |
|
def _infer_GemmFloat8(self, node):
    """Infer the output shape using the shared matmul shape computation."""
    compute_shape = self._compute_matmul_shape
    compute_shape(node)
| |
|
def _infer_LayerNormalization(self, node):
    """Infer output shapes of a LayerNormalization node.

    Output 0 mirrors input 0.  When further outputs exist and the input
    shape is known, outputs 1 and 2 get the input shape with every
    dimension from ``axis`` (default -1) onward replaced by 1; their
    element type is the input's, except FLOAT16/BFLOAT16 become FLOAT.
    """
    self._propagate_shape_and_type(node)
    if len(node.output) <= 1:
        return

    axis = get_attribute(node, "axis")
    if axis is None:
        axis = -1

    x_shape = self._get_shape(node, 0)
    if x_shape is None:
        return

    rank = len(x_shape)
    axis = handle_negative_axis(axis, rank)
    mean_shape = x_shape[:axis] + [1] * (rank - axis)

    mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    if mean_dtype in (onnx.TensorProto.FLOAT16, onnx.TensorProto.BFLOAT16):
        mean_dtype = onnx.TensorProto.FLOAT

    # Output 1 is always present here; output 2 only when declared.
    for stat_idx in range(1, min(len(node.output), 3)):
        vi = self.known_vi_[node.output[stat_idx]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[stat_idx], mean_dtype, mean_shape))
| |
|
def _infer_LongformerAttention(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_EmbedLayerNormalization(self, node):
    """Infer output shapes of an EmbedLayerNormalization node.

    Inputs 0 and 2 must both have rank 2.  Output 0 (and output 2 when
    declared) is the shape of input 0 extended by the second dimension of
    input 2, with the element type of input 2.  A non-empty output 1 is an
    INT32 tensor shaped ``[input0_shape[0]]``.
    """
    input_ids_shape = self._get_shape(node, 0)
    word_embedding_shape = self._get_shape(node, 2)
    assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2

    output_shape = [*input_ids_shape, word_embedding_shape[1]]
    word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type

    def assign(out_idx, elem_type, out_shape):
        # Replace the value info of the given output in place.
        out_vi = self.known_vi_[node.output[out_idx]]
        out_vi.CopyFrom(helper.make_tensor_value_info(node.output[out_idx], elem_type, out_shape))

    assign(0, word_embedding_dtype, output_shape)

    output_count = len(node.output)
    if output_count > 1 and node.output[1]:
        assign(1, onnx.TensorProto.INT32, [input_ids_shape[0]])

    if output_count > 2:
        assign(2, word_embedding_dtype, output_shape)
| |
|
def _infer_SkipLayerNormalization(self, node):
    """Infer output shapes of a SkipLayerNormalization node.

    Output 0 mirrors input 0; when the node declares more than three
    outputs, output 3 mirrors input 0 as well.
    """
    mirrored_outputs = [0]
    if len(node.output) > 3:
        mirrored_outputs.append(3)

    for out_idx in mirrored_outputs:
        self._propagate_shape_and_type(node, 0, out_idx)
| |
|
def _infer_GroupNorm(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_PagedAttention(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_GroupQueryAttention(self, node):
    """Infer output shapes of a GroupQueryAttention node.

    When the shape of input 3 is available, outputs 1 and 2 both get it.
    If inputs 1 and 2 are both non-empty, output 0 mirrors input 0.
    Otherwise both must be empty, and the last dimension of input 0 (when
    it is an int) is treated as ``(num_heads + 2 * kv_num_heads) *
    head_size`` and reduced to ``num_heads * head_size`` for output 0.
    """
    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type

    past_shape = self._try_get_shape(node, 3)
    if past_shape is not None:
        for present_idx in (1, 2):
            vi = self.known_vi_[node.output[present_idx]]
            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))

    if node.input[1] != "" and node.input[2] != "":
        self._propagate_shape_and_type(node, 0, 0)
        return

    # Remaining case: inputs 1 and 2 are both absent, so input 0 carries
    # the combined hidden size.
    assert node.input[1] == "" and node.input[2] == ""
    num_heads = get_attribute(node, "num_heads")
    kv_num_heads = get_attribute(node, "kv_num_heads")

    query_shape = self._get_shape(node, 0)
    if query_shape is None:
        return

    hidden_size = query_shape[2]
    if not isinstance(hidden_size, int):
        return

    head_size = int(hidden_size / (num_heads + 2 * kv_num_heads))
    query_shape[2] = num_heads * head_size
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape))
| |
|
def _infer_SkipGroupNorm(self, node):
    """Mirror input 0 onto output 0 and, when declared, onto output 1."""
    mirrored_count = 2 if len(node.output) > 1 else 1
    for out_idx in range(mirrored_count):
        self._propagate_shape_and_type(node, 0, out_idx)
| |
|
def _infer_BiasSplitGelu(self, node):
    """Infer the output shape of a BiasSplitGelu node.

    When both input shapes are known and the bias length is an int, the
    output is the shape of input 0 with dimension 2 set to half the bias
    length; the element type is that of input 0.
    """
    input_shape = self._get_shape(node, 0)
    bias_shape = self._get_shape(node, 1)
    if not (input_shape and bias_shape and isinstance(bias_shape[0], int)):
        return

    input_shape[2] = int(bias_shape[0] / 2)

    vi = self.known_vi_[node.output[0]]
    output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, input_shape))
| |
|
def _infer_BiasAdd(self, node):
    """Give output 0 the shape and element type of input 0."""
    self._propagate_shape_and_type(node, input_index=0, output_index=0)
| |
|
def _infer_RotaryEmbedding(self, node):
    """Infer output shapes of a RotaryEmbedding node.

    The input mirrored onto each output depends on how many outputs the
    node declares (1, 2 or 3); any other count leaves the outputs
    untouched.
    """
    # output count -> ordered (input_index, output_index) pairs
    routes_by_output_count = {
        1: ((0, 0),),
        2: ((1, 0), (0, 1)),
        3: ((1, 0), (1, 1), (0, 2)),
    }
    for source_idx, target_idx in routes_by_output_count.get(len(node.output), ()):
        self._propagate_shape_and_type(node, input_index=source_idx, output_index=target_idx)
| |
|
def _infer_PythonOp(self, node):
    """Infer output types and shapes of a PythonOp node.

    Output 0 is always declared as an INT64 tensor with an empty shape
    list.  The remaining outputs are described by the shape-inference
    function registered for the node's ``func_name`` when one exists;
    otherwise they get fresh symbolic shapes of the ranks listed in
    ``output_tensor_ranks`` with the types in ``output_tensor_types``.
    """
    output_tensor_types = get_attribute(node, "output_tensor_types")
    assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
    output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
    assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."

    from onnxruntime.capi._pybind_state import get_shape_inference_function

    func_name = get_attribute(node, "func_name").decode()
    shape_inferer = get_shape_inference_function(func_name)

    first_vi = self.known_vi_[node.output[0]]
    first_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))

    tensor_output_count = len(node.output) - 1

    if shape_inferer is None:
        # No registered function: fall back to symbolic shapes of the
        # advertised ranks.
        for i in range(tensor_output_count):
            out_name = node.output[i + 1]
            vi = self.known_vi_[out_name]
            symbolic_shape = get_shape_from_sympy_shape(self._new_symbolic_shape(output_tensor_ranks[i], node))
            vi.CopyFrom(helper.make_tensor_value_info(out_name, output_tensor_types[i], symbolic_shape))
        return

    input_shapes = []
    input_dtypes = []
    for input_index, input_name in enumerate(node.input):
        input_shapes.append(self._get_shape(node, input_index))
        input_dtypes.append(self.known_vi_[input_name].type.tensor_type.elem_type)

    output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
    assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
        f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
        f"but expected {len(node.output) - 1} outputs."
    )
    for i in range(tensor_output_count):
        out_name = node.output[i + 1]
        vi = self.known_vi_[out_name]
        vi.CopyFrom(helper.make_tensor_value_info(out_name, output_dtypes[i], output_shapes[i]))
| |
|
def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
    """Copy the shape and element type of one input onto one output.

    The value info of ``node.output[output_index]`` is replaced by a
    tensor value info carrying the shape and element type of
    ``node.input[input_index]``.
    """
    source_shape = self._get_shape(node, input_index)
    source_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type

    target_name = node.output[output_index]
    target_vi = self.known_vi_[target_name]
    target_vi.CopyFrom(helper.make_tensor_value_info(target_name, source_dtype, source_shape))
| |
|
| | def _is_none_dim(self, dim_value): |
| | if type(dim_value) != str: |
| | return False |
| | if "unk__" not in dim_value: |
| | return False |
| | if dim_value in self.symbolic_dims_: |
| | return False |
| | return True |
| |
|
| | def _is_shape_contains_none_dim(self, out_shape): |
| | for out in out_shape: |
| | if self._is_none_dim(out): |
| | return out |
| | return None |
| |
|
| | def _infer_impl(self, start_sympy_data=None): |
| | self.sympy_data_ = start_sympy_data or {} |
| | self.out_mp_.graph.ClearField("value_info") |
| | self._apply_suggested_merge(graph_input_only=True) |
| | self.input_symbols_ = set() |
| | for i in self.out_mp_.graph.input: |
| | input_shape = get_shape_from_value_info(i) |
| | if input_shape is None: |
| | continue |
| |
|
| | if is_sequence(i.type): |
| | input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim |
| | else: |
| | input_dims = i.type.tensor_type.shape.dim |
| |
|
| | for i_dim, dim in enumerate(input_shape): |
| | if dim is None: |
| | |
| | input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim)) |
| |
|
| | self.input_symbols_.update([d for d in input_shape if type(d) == str]) |
| |
|
| | for s in self.input_symbols_: |
| | if s in self.suggested_merge_: |
| | s_merge = self.suggested_merge_[s] |
| | assert s_merge in self.symbolic_dims_ |
| | self.symbolic_dims_[s] = self.symbolic_dims_[s_merge] |
| | else: |
| | |
| | self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True) |
| | |
| | |
| | |
| | self.tmp_mp_ = onnx.ModelProto() |
| | self.tmp_mp_.CopyFrom(self.out_mp_) |
| | self.tmp_mp_.graph.ClearField("initializer") |
| |
|
| | |
| | |
| | prereq_for_node = {} |
| |
|
| | def get_prereq(node): |
| | names = {i for i in node.input if i} |
| | subgraphs = [] |
| | if node.op_type == "If": |
| | subgraphs = [ |
| | get_attribute(node, "then_branch"), |
| | get_attribute(node, "else_branch"), |
| | ] |
| | elif node.op_type in ["Loop", "Scan"]: |
| | subgraphs = [get_attribute(node, "body")] |
| | for g in subgraphs: |
| | g_outputs_and_initializers = {i.name for i in g.initializer} |
| | g_prereq = set() |
| | for n in g.node: |
| | g_outputs_and_initializers.update(n.output) |
| | for n in g.node: |
| | g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers]) |
| | names.update(g_prereq) |
| | |
| | for i in g.input: |
| | if i.name in names: |
| | names.remove(i.name) |
| | return names |
| |
|
| | for n in self.tmp_mp_.graph.node: |
| | prereq_for_node[n.output[0]] = get_prereq(n) |
| |
|
| | |
| | sorted_nodes = [] |
| | sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)} |
| | if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): |
| | |
| | sorted_nodes = self.out_mp_.graph.node |
| | else: |
| | while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): |
| | old_sorted_nodes_len = len(sorted_nodes) |
| | for node in self.out_mp_.graph.node: |
| | if (node.output[0] not in sorted_known_vi) and all( |
| | [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i] |
| | ): |
| | sorted_known_vi.update(node.output) |
| | sorted_nodes.append(node) |
| | if old_sorted_nodes_len == len(sorted_nodes) and not all( |
| | [o.name in sorted_known_vi for o in self.out_mp_.graph.output] |
| | ): |
| | raise Exception("Invalid model with cyclic graph") |
| |
|
| | for node in sorted_nodes: |
| | assert all([i in self.known_vi_ for i in node.input if i]) |
| | self._onnx_infer_single_node(node) |
| | known_aten_op = False |
| | if node.op_type in self.dispatcher_: |
| | self.dispatcher_[node.op_type](node) |
| | elif node.op_type in ["ConvTranspose"]: |
| | |
| | |
| | |
| | vi = self.known_vi_[node.output[0]] |
| | if len(vi.type.tensor_type.shape.dim) == 0: |
| | vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED |
| | elif node.op_type == "ATen" and node.domain == "org.pytorch.aten": |
| | for attr in node.attribute: |
| | |
| | if attr.name == "operator": |
| | aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s |
| | if aten_op_name in self.aten_op_dispatcher_: |
| | known_aten_op = True |
| | self.aten_op_dispatcher_[aten_op_name](node) |
| | break |
| |
|
| | if self.verbose_ > 2: |
| | logger.debug(node.op_type + ": " + node.name) |
| | for i, name in enumerate(node.input): |
| | logger.debug( |
| | " Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "") |
| | ) |
| |
|
| | |
| | |
| | if node.op_type in [ |
| | "Add", |
| | "Sub", |
| | "Mul", |
| | "Div", |
| | "MatMul", |
| | "MatMulInteger", |
| | "MatMulInteger16", |
| | "Where", |
| | "Sum", |
| | ]: |
| | vi = self.known_vi_[node.output[0]] |
| | out_rank = len(get_shape_from_type_proto(vi.type)) |
| | in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] |
| | for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): |
| | in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] |
| | if len(in_dims) > 1: |
| | self._check_merged_dims(in_dims, allow_broadcast=True) |
| |
|
| | for i_o in range(len(node.output)): |
| | |
| | |
| | |
| | |
| | |
| | if ( |
| | node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" |
| | ) and i_o in [1, 2]: |
| | continue |
| | if node.op_type == "RotaryEmbedding" and len(node.output) > 1: |
| | |
| | |
| | continue |
| |
|
| | vi = self.known_vi_[node.output[i_o]] |
| | out_type = vi.type |
| | out_type_kind = out_type.WhichOneof("value") |
| |
|
| | |
| | if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]: |
| | if self.verbose_ > 2: |
| | if out_type_kind == "sequence_type": |
| | seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") |
| | if seq_cls_type == "tensor_type": |
| | logger.debug( |
| | " {}: sequence of {} {}".format( |
| | node.output[i_o], |
| | str(get_shape_from_value_info(vi)), |
| | onnx.TensorProto.DataType.Name( |
| | vi.type.sequence_type.elem_type.tensor_type.elem_type |
| | ), |
| | ) |
| | ) |
| | else: |
| | logger.debug(f" {node.output[i_o]}: sequence of {seq_cls_type}") |
| | else: |
| | logger.debug(f" {node.output[i_o]}: {out_type_kind}") |
| | continue |
| |
|
| | out_shape = get_shape_from_value_info(vi) |
| | out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED |
| | if self.verbose_ > 2: |
| | logger.debug( |
| | " {}: {} {}".format( |
| | node.output[i_o], |
| | str(out_shape), |
| | onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type), |
| | ) |
| | ) |
| | if node.output[i_o] in self.sympy_data_: |
| | logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]])) |
| |
|
| | |
| | if ( |
| | out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) |
| | ) or out_type_undefined: |
| | if self.auto_merge_: |
| | if node.op_type in [ |
| | "Add", |
| | "Sub", |
| | "Mul", |
| | "Div", |
| | "MatMul", |
| | "MatMulInteger", |
| | "MatMulInteger16", |
| | "Concat", |
| | "Where", |
| | "Sum", |
| | "Equal", |
| | "Less", |
| | "Greater", |
| | "LessOrEqual", |
| | "GreaterOrEqual", |
| | "Min", |
| | "Max", |
| | ]: |
| | shapes = [self._get_shape(node, i) for i in range(len(node.input))] |
| | if node.op_type in [ |
| | "MatMul", |
| | "MatMulInteger", |
| | "MatMulInteger16", |
| | ]: |
| | if None in out_shape or self._is_shape_contains_none_dim(out_shape): |
| | if None in out_shape: |
| | idx = out_shape.index(None) |
| | else: |
| | idx = out_shape.index(self._is_shape_contains_none_dim(out_shape)) |
| | dim_idx = [len(s) - len(out_shape) + idx for s in shapes] |
| | |
| | assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 |
| | assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 |
| | elif node.op_type == "Expand": |
| | |
| | shapes = [ |
| | self._get_shape(node, 0), |
| | self._get_value(node, 1), |
| | ] |
| | else: |
| | shapes = [] |
| |
|
| | if shapes: |
| | for idx in range(len(out_shape)): |
| | if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]): |
| | continue |
| | |
| | |
| | dim_idx = [len(s) - len(out_shape) + idx for s in shapes] |
| | if len(dim_idx) > 0: |
| | self._add_suggested_merge( |
| | [ |
| | s[i] if is_literal(s[i]) else str(s[i]) |
| | for s, i in zip(shapes, dim_idx) |
| | if i >= 0 |
| | ] |
| | ) |
| | self.run_ = True |
| | else: |
| | self.run_ = False |
| | else: |
| | self.run_ = False |
| |
|
| | |
| | if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op: |
| | is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0) |
| | if is_unknown_op: |
| | |
| | |
| | out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1 |
| | else: |
| | |
| | out_rank = len(out_shape) |
| |
|
| | if out_rank >= 0: |
| | new_shape = self._new_symbolic_shape(out_rank, node, i_o) |
| | if out_type_undefined: |
| | |
| | out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type |
| | else: |
| | |
| | out_dtype = vi.type.tensor_type.elem_type |
| | vi.CopyFrom( |
| | helper.make_tensor_value_info( |
| | vi.name, |
| | out_dtype, |
| | get_shape_from_sympy_shape(new_shape), |
| | ) |
| | ) |
| |
|
| | if self.verbose_ > 0: |
| | if is_unknown_op: |
| | logger.debug( |
| | "Possible unknown op: {} node: {}, guessing {} shape".format( |
| | node.op_type, node.name, vi.name |
| | ) |
| | ) |
| | if self.verbose_ > 2: |
| | logger.debug( |
| | " {}: {} {}".format( |
| | node.output[i_o], |
| | str(new_shape), |
| | vi.type.tensor_type.elem_type, |
| | ) |
| | ) |
| |
|
| | self.run_ = True |
| | continue |
| |
|
| | if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: |
| | logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name) |
| | logger.debug("node inputs:") |
| | for i in node.input: |
| | if i in self.known_vi_: |
| | logger.debug(self.known_vi_[i]) |
| | else: |
| | logger.debug(f"not in known_vi_ for {i}") |
| | logger.debug("node outputs:") |
| | for o in node.output: |
| | if o in self.known_vi_: |
| | logger.debug(self.known_vi_[o]) |
| | else: |
| | logger.debug(f"not in known_vi_ for {o}") |
| | if self.auto_merge_ and not out_type_undefined: |
| | logger.debug("Merging: " + str(self.suggested_merge_)) |
| | return False |
| |
|
| | self.run_ = False |
| | return True |
| |
|
| | def _update_output_from_vi(self): |
| | for output in self.out_mp_.graph.output: |
| | if output.name in self.known_vi_: |
| | output.CopyFrom(self.known_vi_[output.name]) |
| |
|
@staticmethod
def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
    """Run symbolic shape inference on ``in_mp`` and return the resulting model.

    Args:
        in_mp: Model handed to ``get_opset`` and to ``_preprocess``.
        int_max: Forwarded to the ``SymbolicShapeInference`` constructor.
        auto_merge: Forwarded to the ``SymbolicShapeInference`` constructor.
        guess_output_rank: Forwarded to the ``SymbolicShapeInference`` constructor.
        verbose: Forwarded to the ``SymbolicShapeInference`` constructor.

    Returns:
        ``None`` (after logging a warning) when ``get_opset`` yields a falsy
        value or a value below 7; otherwise the ``out_mp_`` model of the
        inference object.

    Raises:
        Exception: When the last ``_infer_impl`` pass did not report success
            (or no pass ran at all). Before raising, the partially inferred
            model is written to ``sym_shape_infer_temp.onnx`` with external
            data enabled.
    """
    opset = get_opset(in_mp)
    if not opset or opset < 7:
        logger.warning("Only support models of onnx opset 7 and above.")
        return None

    engine = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
    engine._preprocess(in_mp)

    # Keep running passes for as long as the engine asks for another one;
    # only the result of the final pass decides success.
    converged = False
    while engine.run_:
        converged = engine._infer_impl()
    engine._update_output_from_vi()

    if converged:
        return engine.out_mp_

    # Persist the partial result before failing.
    onnx.save_model(engine.out_mp_, "sym_shape_infer_temp.onnx", save_as_external_data=True)
    raise Exception("Incomplete symbolic shape inference")
| |
|
| |
|
def parse_arguments(argv=None):
    """Parse the command-line options of the symbolic shape inference script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) ``argparse``
            reads ``sys.argv[1:]``, which is what existing callers that pass
            no argument rely on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input`` (required),
        ``output``, ``auto_merge``, ``int_max``, ``guess_output_rank``,
        ``verbose``, ``save_as_external_data``, ``all_tensors_to_one_file``,
        ``external_data_location`` and ``external_data_size_threshold``.

    Raises:
        SystemExit: Raised by ``argparse`` when ``--input`` is missing or an
            option cannot be parsed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="The input model file")
    parser.add_argument("--output", help="The output model file")
    # ``store_true`` flags already default to False, so no explicit default is needed.
    parser.add_argument(
        "--auto_merge",
        help="Automatically merge symbolic dims when confliction happens",
        action="store_true",
    )
    parser.add_argument(
        "--int_max",
        help="maximum value for integer to be treated as boundless for ops like slice",
        type=int,
        default=2**31 - 1,
    )
    parser.add_argument(
        "--guess_output_rank",
        help="guess output rank to be the same as input 0 for unknown ops",
        action="store_true",
    )
    parser.add_argument(
        "--verbose",
        help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--save_as_external_data",
        help="Saving an ONNX model to external data",
        action="store_true",
    )
    parser.add_argument(
        "--all_tensors_to_one_file",
        help="Saving all the external data to one file",
        action="store_true",
    )
    parser.add_argument(
        "--external_data_location",
        help="The file location to save the external file",
        default="./",
    )
    parser.add_argument(
        "--external_data_size_threshold",
        help="The size threshold for external data",
        type=int,
        default=1024,
    )
    # Passing None keeps the historical behavior of parsing sys.argv[1:].
    return parser.parse_args(argv)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse options, run inference, optionally save the result.
    args = parse_arguments()
    logger.info(f"input model: {args.input}")
    if args.output:
        logger.info(f"output model {args.output}")

    logger.info("Doing symbolic shape inference...")
    out_mp = SymbolicShapeInference.infer_shapes(
        in_mp=onnx.load(args.input),
        int_max=args.int_max,
        auto_merge=args.auto_merge,
        guess_output_rank=args.guess_output_rank,
        verbose=args.verbose,
    )

    # Only write a file when an output path was given and inference produced a model.
    if args.output and out_mp:
        if not args.save_as_external_data:
            onnx.save(out_mp, args.output)
        else:
            onnx.save_model(
                out_mp,
                args.output,
                save_as_external_data=True,
                all_tensors_to_one_file=args.all_tensors_to_one_file,
                location=args.external_data_location,
                size_threshold=args.external_data_size_threshold,
                convert_attribute=False,
            )
        logger.info("Done!")
| |
|