
Adding prod & prod_dim #1173

Closed · wants to merge 12 commits

Conversation

unrenormalizable (Contributor)

Pull Request Template

NOTE:

  1. The following things still need to be done; I'm creating this PR to get guidance on them. I point these out in their respective places in the comments section of this PR.
  2. tch, ndarray & candle don't support the prod_axis required for this implementation. QUESTION: should I implement it in the burn layer with some gymnastics, or leave it as panic!() like, e.g., burn-candle's int_div_scalar?
  3. There is a large copy-paste from the sum implementation.
  4. I was expecting to write some more directed tests, especially in burn-wgpu, but couldn't find tests for the other ops. Am I missing something, or do we expect the ops to be tested by the higher layers?

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

n/a

Changes

Adds prod & prod_dim in preparation for adding the ReduceProd ONNX op. Honestly a mostly useless check-in; I'm doing it to get familiar with the codebase + for fun.

Testing

run-checks all

@@ -311,6 +311,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
CandleTensor::new(tensor.tensor.sum_keepdim(dim).unwrap())
}

fn int_prod<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
todo!();
Contributor Author

tch, ndarray & candle don't support the prod_axis required for this implementation.

QUESTION: should I implement it in the burn layer with some gymnastics, or leave it as panic!() like, e.g., burn-candle's int_div_scalar?

Member

I think it's rather bad to offer an operation in the API if it's gonna fail on 75% of our backends. I think if you can work out some default implementation in the burn layer it would be cool, otherwise we should not even offer the operation.
Candle's int_div_scalar is an exception that we should fix eventually.

Contributor Author

I agree. I will file issues on tch, ndarray and candle and implement a default.
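
For reference, here is one way a burn-layer default could look conceptually. This is a minimal standalone sketch over a plain row-major buffer, with a made-up function name and signature; it is not burn's tensor API, just an illustration of the indexing a default prod_dim would need.

```rust
// Hypothetical sketch (not burn's API): product-reduce a row-major buffer
// `data` with shape `shape` along axis `dim`, keeping that axis with size 1.
fn prod_dim_naive(data: &[i64], shape: &[usize], dim: usize) -> (Vec<i64>, Vec<usize>) {
    let outer: usize = shape[..dim].iter().product();
    let axis = shape[dim];
    let inner: usize = shape[dim + 1..].iter().product();

    // Start from the multiplicative identity and fold the reduced axis in.
    let mut out = vec![1i64; outer * inner];
    for o in 0..outer {
        for a in 0..axis {
            for i in 0..inner {
                out[o * inner + i] *= data[(o * axis + a) * inner + i];
            }
        }
    }

    let mut out_shape = shape.to_vec();
    out_shape[dim] = 1;
    (out, out_shape)
}
```

A real default would presumably be assembled from existing backend primitives rather than host-side loops, but the shape/indexing bookkeeping is the same idea.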

@@ -0,0 +1,116 @@
use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
Contributor Author

This is a large copy-paste from the sum implementation.

What is the recommendation here? Parametrize the core implementation and call it with sum / prod arguments?

Member

I guess we'll want to refactor this so that all reduce operations share some core. The list of reduction operations is growing (there's also PR #1136 coming up with more autotuned reduce), so we can't afford to have this many duplicates. For now you can keep it as is, but I'll open an issue for refactoring this.
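
One possible shape for that shared core, as a hedged sketch (the names here are invented for illustration and are not from this PR or from #1136): parametrize the reduction by its identity element and update step, then instantiate it for sum and prod.

```rust
// Hypothetical sketch of a shared reduce core parametrized over the operation.
trait ReduceOp {
    const INITIAL: f32;
    fn update(acc: f32, value: f32) -> f32;
}

struct SumReduce;
impl ReduceOp for SumReduce {
    const INITIAL: f32 = 0.0;
    fn update(acc: f32, value: f32) -> f32 {
        acc + value
    }
}

struct ProdReduce;
impl ReduceOp for ProdReduce {
    const INITIAL: f32 = 1.0;
    fn update(acc: f32, value: f32) -> f32 {
        acc * value
    }
}

// Everything else (kernel templates, autotune sets, dispatch) would be written
// once against `R: ReduceOp` instead of being copy-pasted per operation.
fn reduce<R: ReduceOp>(values: &[f32]) -> f32 {
    values.iter().fold(R::INITIAL, |acc, &v| R::update(acc, v))
}
```

For example, `reduce::<ProdReduce>(&[2.0, 3.0, 4.0])` yields 24.0, while `reduce::<SumReduce>` on the same input yields 9.0.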

@unrenormalizable unrenormalizable changed the base branch from main to book/guide January 25, 2024 04:37
@unrenormalizable unrenormalizable changed the base branch from book/guide to main January 25, 2024 04:37
@louisfd (Member) left a comment


Hi @unrenormalizable
Thanks for the PR draft, it's looking good. See my comments below.

(WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(),
)
.register("initial", 0.0.to_string())
.register("update", "shared_memory[local_id] += value; ")
Member

I guess you'll want initial to be 1.0 and the update to be *=
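
Concretely, assuming the builder calls are the same ones shown in the diff above, the suggested change would look something like this fragment (a sketch, not a tested patch):

```rust
// Product reduction: multiplicative identity, multiplicative update.
.register("initial", 1.0.to_string())
.register("update", "shared_memory[local_id] *= value; ")
```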


workgroupBarrier();

if id_local == 0u {
    var prod = {{ elem }}(0);
Member

You should start at 1.0 or it won't compute much ;)


let ones = B::ones(shape, &B::device(&grad));
let grad = B::prod_dim(grad, dim);

B::mul(ones, grad)
Member

Why multiply by one?


unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let ones = B::ones(shape, &B::device(&grad));
let grad = B::prod_dim(grad, dim);
Member

As an example, I think the derivative of prod_dim([a, b, c, d], 0) should be
[bcd, acd, abd, abc], then multiplied by the unchanged grad. So you'll need to register the original input as state. Don't hesitate to call me out if you think I'm wrong; I haven't thought about it for too long.

Contributor Author

Yeah, my bad. I will add unit tests for this.
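
For reference, a minimal plain-Rust sketch of the gradient described above (illustrative name and signature, not burn's autodiff API): each input element receives the product of all the other elements, scaled by the upstream gradient.

```rust
// Hypothetical sketch: backward of y = prod(x) for a 1-D input,
// i.e. dy/dx[i] = product of all x[j] with j != i, times the upstream grad.
fn prod_backward_naive(x: &[f64], upstream_grad: f64) -> Vec<f64> {
    (0..x.len())
        .map(|i| {
            let others: f64 = x
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .map(|(_, &v)| v)
                .product();
            others * upstream_grad
        })
        .collect()
}
```

For x = [a, b, c, d] this produces [bcd, acd, abd, abc] scaled by the upstream gradient, matching the example above, and it is why the original input has to be kept as state.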

@antimora (Collaborator)

Great! Let me know when you start working on the ONNX part.

@unrenormalizable (Contributor Author) commented Jan 30, 2024

Great! Let me know when you start working on the ONNX part.

Will do, and my apologies, this is taking way longer than expected. I overestimated my learning abilities 🤣

@antimora (Collaborator)

No worries. We are here to help.

@antimora added the feature label on Jan 31, 2024
# Conflicts:
#	burn-autodiff/src/ops/tensor.rs
#	burn-candle/src/lib.rs
#	burn-candle/src/ops/tensor.rs
#	burn-fusion/src/ops/float.rs
#	burn-fusion/src/stream/operation.rs
#	burn-ndarray/src/ops/tensor.rs
#	burn-tch/src/ops/tensor.rs
#	burn-wgpu/src/ops/float_ops.rs
@unrenormalizable (Contributor Author)

Folks, I am going to pause the prod/prod_dim work until the required dependencies are available.

The main blocker is that cumprod is required to implement the autodiff/grad component for prod. This follows the PyTorch implementation, which is a reasonable approach (see the sketch after the list below).

In addition, I have filed the prod_axis requirement on huggingface/candle/1620 & rust-ndarray/ndarray/1351.

The following items from this PR will be taken forward in a separate PR:

  1. reduce_dim_sum.wgsl edge-case fix
  2. reduce_dim_shared_memory.wgsl initialization issue
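
For context on why cumprod is the blocker: the per-element "product of everything except x[i]" needed for the gradient can be formed without division (and therefore without special-casing zeros) from an exclusive prefix cumulative product and an exclusive suffix cumulative product. A hedged plain-Rust sketch of that idea (not the PyTorch or burn code):

```rust
// Hypothetical sketch: prod backward via cumulative products, no division.
// grad[i] = prefix[i] * suffix[i] * upstream, where
//   prefix[i] = x[0] * ... * x[i-1]   (exclusive prefix product)
//   suffix[i] = x[i+1] * ... * x[n-1] (exclusive suffix product)
fn prod_backward_cumprod(x: &[f64], upstream_grad: f64) -> Vec<f64> {
    let n = x.len();
    let mut grad = vec![0.0; n];

    // Exclusive prefix products.
    let mut prefixes = Vec::with_capacity(n);
    let mut prefix = 1.0;
    for &v in x {
        prefixes.push(prefix);
        prefix *= v;
    }

    // Exclusive suffix products, combined with the prefixes on the fly.
    let mut suffix = 1.0;
    for i in (0..n).rev() {
        grad[i] = prefixes[i] * suffix * upstream_grad;
        suffix *= x[i];
    }

    grad
}
```

On a tensor backend, both passes map onto a cumprod primitive (forward and reversed), which is the dependency being waited on.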
