Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aten.as_strided converter #2735

Merged
merged 4 commits into from May 22, 2024
Merged

Conversation

chohk88
Copy link
Collaborator

@chohk88 chohk88 commented Apr 9, 2024

Description

New feature to support aten.as_strided converter converter. Our implementation focuses on accurately calculating the indices required

Fixes # (issue)

Example Usage

Given a tensor x:

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

Computing indices_tensor

Example 1: Stride=[1, 2]

  • Operation: torch.as_strided(x, (2, 2), (1, 2), 0)
  • indices_tensor: [0, 2, 1, 3] from stride [1, 2].
  • Output: [[1, 3], [2, 4]]

Example 2: Stride=[2, 1]

  • Operation: torch.as_strided(x, (2, 2), (2, 1), 0)
  • indices_tensor: [0, 1, 2, 3] from stride [2, 1].
  • Output: [[1, 2], [3, 4]]

as_strided converter calculates indices_tensor using size and stride, then reshapes the tensor accordingly

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 9, 2024
@github-actions github-actions bot requested a review from peri044 April 9, 2024 09:59
@chohk88 chohk88 requested review from zewenli98, apbose and gs-olive and removed request for peri044 April 9, 2024 09:59
@chohk88 chohk88 linked an issue Apr 9, 2024 that may be closed by this pull request
Comment on lines +285 to +298
# Recursive function to compute indices for as_strided operation
def nested(
rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
) -> None:
if (
dim == rank
): # If the current dimension equals the rank, append the computed index
indices.append(current)
return
for i in range(size[dim]): # Recursively compute indices across dimensions
nested(
rank, size, stride, current + stride[dim] * i, dim + 1
) # Calculate the index for the current dimension and recursively explore further dimensions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function appears to only perform work when dim == rank.

Would it be possible to instead replace with an action that only occurs when dim == rank, for instance something like:

indices = [(storage_offset + stride[-1] * i) for i in range(size[-1])]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review and the suggestion to simplify the function.

The proposed list comprehension indeed simplifies the index computation for cases where we are only dealing with the last dimension. However, the recursive function was designed to handle multi-dimensional tensors where each dimension could potentially affect the computed index due to its stride and size.

rank, size, stride, current + stride[dim] * i, dim + 1
) # Calculate the index for the current dimension and recursively explore further dimensions

nested(len(size), size, stride, storage_offset, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above comment


nested(len(size), size, stride, storage_offset, 0)

indices = torch.tensor(indices, dtype=torch.int)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For torch.compile, it is preferable to use np.array here. The subsequent get_trt_tensor call can handle that as well.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review! I've updated the code based on your suggestions.

stride: Sequence[int],
storage_offset: int,
) -> TRTTensor:
assert len(size) == len(stride), "size and stride shapes must be the same"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed if it is required for Torch to pass. If this condition causes a failure in Torch, it is not required here, since Torch will catch it in the original model when attempting inference.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestions! I have now made the adjustments to the code.

ctx,
target,
source_ir,
f"{name}_reshape",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicate with the name above on line 280 - needs to be distinct, perhaps f"{name}_reshape_gather_output".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review! I've updated the code based on your suggestions. And also I have changed the name of flatten_output layer.

input=args[0],
size=args[1],
stride=args[2],
storage_offset=args_bounds_check(args, 3, 0),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the schema, storage_offset will be None if the arg is omitted. Thus, please change to storage_offset=args_bounds_check(args, 3, None),

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestions! I've made the changes to the code. I also made sure that storage_offset is an integer before using it in the nested function.

input: TRTTensor,
size: Sequence[int],
stride: Sequence[int],
storage_offset: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the schema above, here should be storage_offset: Optional[int]. Then, the following code should be able to deal with storage_offset=None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I appreciate your detailed reviews. I've made the changes to the code.


indices = torch.tensor(indices, dtype=torch.int)

indices_tensor = get_trt_tensor(ctx, (indices), f"{name}_indices")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of (indices), is this bracket necessary?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I have removed unnecessary bracket.

(9, 9, 3, 2),
(2, 2, 2, 3),
1,
),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more tests that size or stride contains 0, and/or the input is not a square/cube?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestions! I've added more test cases that include scenarios where the stride contains 0 and the input is not a square/cube. However, I encountered issues as we don't support output tensors with zero-sized dimensions for this operation. Therefore, I've added a validator to handle such cases.

I have a question regarding when the output shape contains zero. For instance, when using torch.as_strided(x, (0, 1), (2, 1), 0), it returns a tensor with a zero-sized dimension.
image

I attempted to implement functionality to return an empty tensor with the same shape, as shown below:

if 0 in size:
    empty_out = np.empty(size, dtype=np.float64)
    empty_out = get_trt_tensor(ctx, empty_out, f"{name}_empty")
    print(empty_out.dtype)
    return empty_out

# or

if 0 in size:
    empty_out = np.zeros(size, dtype=np.float64)
    empty_out = get_trt_tensor(ctx, empty_out, f"{name}_empty")
    print(empty_out.dtype)
    return empty_out

However, this approach leads to the following errors:

INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 134, GPU 1262 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +1750, GPU +289, now: CPU 2019, GPU 1551 (MiB)
ERROR:torch_tensorrt [TensorRT Conversion Context]:3: [constantLayer.cpp::setWeights::40] Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/layers/constantLayer.cpp::setWeights::40, condition: !weights.values == !weights.count )
DataType.???
ERROR:torch_tensorrt [TensorRT Conversion Context]:2: [graphShapeAnalyzer.cpp::checkCalculationStatusSanity::1607] Error Code 2: Internal Error (Assertion !isPartialWork(p.second.outputExtents) failed. )
ERROR:torch_tensorrt [TensorRT Conversion Context]:2: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2270] Error Code 2: Internal Error (Assertion !isPartialWork(status.outputExtents) failed. )
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:01.714407
ERROR:torch_tensorrt [TensorRT Conversion Context]:2: [graphShapeAnalyzer.cpp::needTypeAndDimensions::2270] Error Code 2: Internal Error (Assertion !isPartialWork(status.outputExtents) failed. )
F

Do you have any suggestions on how to handle returning tensors with an output size containing zero?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, it looks like the output doesn't support empty in torch-trt. I think it's ok to add a validator like you have done, but need to remove the duplicated decorator: https://github.com/pytorch/TensorRT/pull/2735/files#diff-9cbf43cfb62cdf682eef77f6fd5cabc488bb5a91ff09bfee5d593827542b2476R2203

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2024-05-22 03:55:21.086751+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py	2024-05-22 03:57:25.753026+00:00
@@ -2217,10 +2217,11 @@
        input=args[0],
        weight=args[1],
        bias=args_bounds_check(args, 2, None),
    )

+
@dynamo_tensorrt_converter(torch.ops.aten._cdist_forward.default)
def aten_ops_cdist_forward(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],

@chohk88 chohk88 merged commit 9341e9b into main May 22, 2024
36 of 51 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests needs-release-cherrypick
Projects
None yet
Development

Successfully merging this pull request may close these issues.

aten.as_strided
5 participants