Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: as_offset implementation in embedded #1397

Open
wants to merge 45 commits into
base: main
Choose a base branch
from

Conversation

nfarabullini
Copy link
Contributor

@nfarabullini nfarabullini commented Dec 13, 2023

implementation of as_offset in embedded

@nfarabullini nfarabullini marked this pull request as draft December 13, 2023 12:54
@nfarabullini nfarabullini changed the title as_offset implementation in embedded feat[next]: as_offset implementation in embedded Dec 13, 2023
src/gt4py/next/embedded/nd_array_field.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/nd_array_field.py Outdated Show resolved Hide resolved
@havogt
Copy link
Contributor

havogt commented Jan 5, 2024

For later reference:
Here a proposal for our multidimensional version of take data-apis/array-api#669 (if I understand correctly, but currently there is only this https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html.

@nfarabullini nfarabullini marked this pull request as ready for review January 5, 2024 13:18
@edopao
Copy link
Contributor

edopao commented Jan 17, 2024

cscs-ci run

Copy link
Contributor

@havogt havogt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, still a few comments.

src/gt4py/next/common.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/nd_array_field.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/nd_array_field.py Outdated Show resolved Hide resolved
@@ -198,8 +216,11 @@ def remap(
# then compute the index array
xp = self.array_ns
new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start
# finally, take the new array
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)
if self._ndarray.ndim > 1 and restricted_connectivity_domain == new_domain:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the second part of this condition? restricted_connectivity_domain == new_domain

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to avoid entering this condition in cases like:

    @gtx.field_operator
    def testee(a: gtx.Field[[Vertex, KDim], float]) -> gtx.Field[[Edge, KDim], float]:
        return a(E2V[0])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain this if else branch and are you sure all cases are handled? I am confused...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When using FieldOffsets, only the specific dimensions related to the offset are taken into account.
Say I have this field_operator:

    @gtx.field_operator
    def testee(a: gtx.Field[[Edge, KDim], int]) -> gtx.Field[[Vertex, KDim], int]:
        tmp = neighbor_sum(a(V2E), axis=V2EDim)
        return tmp

Here the restricted_connectivity_domain will be over [Edge, V2E] and will exclude KDim. In this case using the regular xp.take works.

When using as_offset, xp.take is also ok to use if the offset_field contains only one dimension.

However, when restricted_connectivity_domain contains multiple dimensions that are exactly the same as in new_domain, we have seen that xp.take does not work and hence had to create _take_mdim

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but what about restricted_connectivity_domain.dims == new_domain.dims, but ranges are different?

src/gt4py/next/ffront/fbuiltins.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/nd_array_field.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/nd_array_field.py Outdated Show resolved Hide resolved
src/gt4py/next/embedded/nd_array_field.py Outdated Show resolved Hide resolved
src/gt4py/next/ffront/experimental.py Outdated Show resolved Hide resolved
@nfarabullini
Copy link
Contributor Author

cscs-ci run

@nfarabullini
Copy link
Contributor Author

cscs-ci run

@nfarabullini
Copy link
Contributor Author

cscs-ci run

offset_provider={"Ioff": IDim, "Koff": KDim},
ref=a[2:],
comparison=lambda out, ref: np.all(out == ref),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed we might be missing another test case: let's say we have a 3D field but the offset field is only 2D. I think the expected semantic is probably as if the offset field would be broadcasted first. This might be related to my comment about

if self._ndarray.ndim > 1 and restricted_connectivity_domain == new_domain:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants