Skip to content

Commit

Permalink
Allow ReferenceEvaluator to return intermediate results (onnx#6066)
Browse files Browse the repository at this point in the history
### Description

Intermediate results can only be printed right now. With this PR, they
can be returned as well.

### Motivation and Context
See onnx#6025.

---------

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
  • Loading branch information
xadupre authored and gramalingam committed Apr 12, 2024
1 parent 92d0ed4 commit 32fdc58
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 17 deletions.
17 changes: 15 additions & 2 deletions onnx/reference/reference_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,17 +542,27 @@ def _load_impl( # noqa: PLR0911
f"is unknown, known functions: {sorted(self.functions_)}."
)

def run(self, output_names, feed_inputs: Dict[str, Any], attributes: Optional[Dict[str, Any]] = None): # type: ignore
def run(
self,
output_names,
feed_inputs: Dict[str, Any],
attributes: Optional[Dict[str, Any]] = None,
intermediate: bool = False,
) -> Union[Dict[str, Any], List[Any]]: # type: ignore
"""Executes the onnx model.
Args:
output_names: requested outputs by names, None for all
feed_inputs: dictionary `{ input name: input value }`
attributes: attributes value if the instance runs a
FunctionProto
intermediate: if True, the function returns all the results,
final ones and intermediates one in a same dictionary,
if False, only the final results are returned in a list
Returns:
list of requested outputs
list of requested outputs if intermediate is False,
named results in a dictionary otherwise
"""
if output_names is None:
output_names = self.output_names
Expand Down Expand Up @@ -591,6 +601,9 @@ def run(self, output_names, feed_inputs: Dict[str, Any], attributes: Optional[Di
results[name] = value

# return the results
if intermediate:
return results

for name in output_names:
if name not in results:
raise RuntimeError(
Expand Down
2 changes: 1 addition & 1 deletion onnx/test/model_container_refeval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def common_check_reference_evaluator(self, container):
],
dtype=np.float32,
)
npt.assert_allclose(expected, got[0])
npt.assert_allclose(expected, got[0]) # type: ignore[index]

def test_large_onnx_no_large_initializer(self):
model_proto = _linear_regression()
Expand Down
15 changes: 15 additions & 0 deletions onnx/test/reference_evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,21 @@ def test_reference_evaluator_no_attribute(self):
expected = (x + y) * (y - z)
assert_allclose(expected, res)

def test_reference_evaluator_no_attribute_intermediate(self):
m = TestReferenceEvaluator._load_model(TestReferenceEvaluator.m2_def)
checker.check_model(m)
sess = ReferenceEvaluator(m)
self.assertEqual(sess.input_names, ["B01", "B11", "B21"])
self.assertEqual(sess.output_names, ["D0"])
self.assertEqual(sess.opsets, {"": 10, "com.microsoft": 1})
x = np.array([[0, 1], [2, 3]], dtype=np.float32)
y = np.array([[4, 5], [6, 7]], dtype=np.float32)
z = np.array([[-4, -5], [-6, -7]], dtype=np.float32)
res = sess.run(None, {"B01": x, "B11": y, "B21": z}, intermediate=True)
self.assertIsInstance(res, dict)
expected = (x + y) * (y - z)
assert_allclose(expected, res["D0"])

def test_reference_evaluator_no_attribute_bytes(self):
m = TestReferenceEvaluator._load_model(TestReferenceEvaluator.m2_def)
checker.check_model(m)
Expand Down
28 changes: 14 additions & 14 deletions onnx/test/tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def test_replace_initializer(self):

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0]
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def)
node_types = {n.op_type for n in repl.graph.node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 3.5
y1[0, :] = 0.5
y2 = oinf2.run(None, {"X": x})[0]
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
assert_allclose(y1, y2)

def test_replace_constant(self):
Expand All @@ -101,14 +101,14 @@ def test_replace_constant(self):

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0]
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def)
node_types = {n.op_type for n in repl.graph.node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 3.5
y1[0, :] = 0.5
y2 = oinf2.run(None, {"X": x})[0]
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
assert_allclose(y1, y2)

def test_replace_range(self):
Expand All @@ -128,13 +128,13 @@ def test_replace_range(self):

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0]
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def, use_range=True)
node_types = {n.op_type for n in repl.graph.node}
self.assertIn("Range", node_types)
self.assertNotIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y2 = oinf2.run(None, {"X": x})[0]
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
assert_allclose(y1.shape, y2.shape)

def test_replace_constant_function(self):
Expand Down Expand Up @@ -171,14 +171,14 @@ def test_replace_constant_function(self):

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0]
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def)
node_types = {n.op_type for n in repl.functions[0].node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 3.5
y1[0, :] = 0.5
y2 = oinf2.run(None, {"X": x})[0]
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
assert_allclose(y1, y2)

def test_replace_range_function(self):
Expand Down Expand Up @@ -215,13 +215,13 @@ def test_replace_range_function(self):

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0]
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def, use_range=True)
node_types = {n.op_type for n in repl.functions[0].node}
self.assertIn("Range", node_types)
self.assertNotIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y2 = oinf2.run(None, {"X": x})[0]
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
assert_allclose(y1.shape, y2.shape)

def test_replace_constant_graph(self):
Expand Down Expand Up @@ -264,11 +264,11 @@ def test_replace_constant_graph(self):

x = np.ones((3, 2), dtype=np.float32)
oinf1 = ReferenceEvaluator(onnx_model)
y1 = oinf1.run(None, {"X": x})[0]
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(onnx_model)
self.assertIn("ConstantOfShape", str(repl))
oinf2 = ReferenceEvaluator(repl)
y2 = oinf2.run(None, {"X": x})[0]
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
y1 = y1.copy()
y1[:] = 0.5
assert_allclose(y1, y2)
Expand Down Expand Up @@ -313,12 +313,12 @@ def test_replace_range_graph(self):

x = np.ones((3, 2), dtype=np.float32)
oinf1 = ReferenceEvaluator(onnx_model)
y1 = oinf1.run(None, {"X": x})[0]
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(onnx_model, use_range=True)
self.assertNotIn("ConstantOfShape", str(repl))
self.assertIn("Range", str(repl))
oinf2 = ReferenceEvaluator(repl)
y2 = oinf2.run(None, {"X": x})[0]
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
assert_allclose(y1.shape, y2.shape)


Expand Down

0 comments on commit 32fdc58

Please sign in to comment.