-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor[next]: embedded with itir.Program #1530
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: Till Ehrengruber <t.ehrengruber@me.com>
Co-authored-by: Till Ehrengruber <t.ehrengruber@me.com>
src/gt4py/next/iterator/ir.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a question, not related to this PR but in general to the new IR and specifically to this type:
class SetAt(Stmt): # from JAX array.at[...].set()
expr: Expr # only `as_fieldop(stencil)(inp0, ...)` in first refactoring
domain: Expr
target: Expr # `make_tuple` or SymRef
Do we really need to support make_tuple
as target expression?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's the current representation for something like
@fundef
def prog(a,b,c)
setat(as_field_op(lambda x: make_tuple(deref(x)+1, deref(x)+2))(c), domain, make_tuple(a,b))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the fieldview backend support that representation as is, I mean with inlined make_tuple
? or just this one?
@fundef
def prog(a,b,c)
setat(as_field_op(lambda x, y: make_tuple(x, y))(as_field_op(lambda x: x+1)(c), as_field_op(lambda x: x+2)(c), ), domain, make_tuple(a,b))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am asking because from a lowering perspective, make_tuple
and tuple_get
do not implement any kind of computation on fields. Therefore, it is difficult to represent them in my map-tasklet graph. Would it be too strange to treat these builtins on the same level as as_field_op
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fundef
def prog(a,b,c)
setat(make_tuple(as_field_op(lambda x: x+1)(c), as_field_op(lambda x: x+2)(c)), domain, make_tuple(a,b))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First round.
def field_from_typespec( | ||
domain: common.Domain, xp: ModuleType | ||
) -> Callable[..., common.MutableField | tuple[common.MutableField | tuple, ...]]: | ||
@utils.tree_map(collection_type=ts.TupleType, result_collection_type=tuple) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is also useful to users (we have a very similar function in PMAP https://github.com/PMAP-Project/PMAP-G/blob/68894a28265d30934ab720373ef60532f2776c84/src/gfvm/model/state_container.py#L22), let's add a docstring + doctest. Additionally I would prefer to make this a classmethod of NDArrayField
something like:
NDArrayField.from_type_spec()
. This would be analogous to NDArrayField.from_array
.
return not a | ||
|
||
|
||
@builtins.gamma.register(EMBEDDED) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not worth spending a lot of time here, but since Colum implements __array_ufunc__
I am surprised this special handling is required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gamma is not a numpy function
@@ -210,6 +210,7 @@ def __bool__(self): | |||
class TracerContext: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does the TracerContext need to support both closures and statements? Because the tests are written in the old way and we can change them until we updated the passes? Let's add a short todo with explanation here.
@@ -241,6 +247,11 @@ def closure(domain, stencil, output, inputs): | |||
) | |||
|
|||
|
|||
@iterator.runtime.set_at.register(TRACING) | |||
def set_at(expr, domain, target): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Type annotations missing.
) -> Callable[[Callable[_P, _R]], Callable[..., _R | tuple[_R | tuple, ...]]]: ... | ||
|
||
|
||
def tree_map( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is time to remove apply_to_primitive_constituents
and absorb it in this function. I'll check if @SF-N has capacity to work on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sara will extract this into a new PR and also remove the apply_to_primitive_constituents
cases there. I'll keep you in the loop.
tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
Outdated
Show resolved
Hide resolved
cscs-ci run |
Updates itir.embedded to work with
itir.Progam
s, i.e.set_at
andas_fieldop
.For programs to be able to run in embedded, the domain needs to be provided as second argument to
as_fieldop
.Introduces a
DimensionKind
toitir.AxisLiteral
to be able to reconstruct the kind from the IR. This is needed now as theset_at
assigns from field to field, which requires matching dimensions. However, previously the python program generated from IR would always construct horizontal dimensions (but the information would not be used).