Skip to content

Add kernel SparseSegmentSum for CPU and WebGL backend #5018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 3, 2021

Conversation

ahmedsabie
Copy link
Contributor

@ahmedsabie ahmedsabie commented Apr 30, 2021

ref #4838
TensorFlow python version is defined here
c++ kernel definition (SparseSegmentReductionSumOp) is here

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

@ahmedsabie ahmedsabie requested review from pyu10055 and lina128 April 30, 2021 21:42
@google-cla google-cla bot added the cla: yes label Apr 30, 2021
Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, great work!

Reviewable status: 0 of 1 approvals obtained (waiting on @ahmedsabie and @lina128)


tfjs-backend-cpu/src/kernels/SparseSegmentReduction_impl.ts, line 22 at r1 (raw file):

export function sparseSegmentReductionImpl(
    input: TypedArray, inputShape: number[], inputDType: DataType,
    indices: TypedArray, segmentIds: TypedArray, isMean = false,

parameters isMean is not used ?


tfjs-backend-cpu/src/kernels/SparseSegmentReduction_impl.ts, line 61 at r1 (raw file):

  const outputLength =
      outputShape.reduce((product, value) => product * value, 1);
  const output = util.getArrayFromDType(inputDType, outputLength) as TypedArray;

Please comment that the output type array is initialized with 0 by default.


tfjs-backend-cpu/src/kernels/SparseSegmentReduction_impl.ts, line 117 at r1 (raw file):

      }
    }
    // const badOffset =

removed here?


tfjs-core/src/ops/sparse/sparse_segment_sum.ts, line 32 at r1 (raw file):

 * // Select two rows, one segment.
 * const result1 = tf.sparse.sparseSegmentSum(c,
 *                                           tf.tensor1d([0, 1], 'int32),

missing right ' across all examples


tfjs-core/src/ops/sparse/sparse_segment_sum.ts, line 34 at r1 (raw file):

 *                                           tf.tensor1d([0, 1], 'int32),
 *                                           tf.tensor1d([0, 0], 'int32);
 * console.log(result1);

console.log() is duplicated with tensor.print() method`


tfjs-core/src/ops/sparse/sparse_segment_sum.ts, line 35 at r1 (raw file):

 *                                           tf.tensor1d([0, 0], 'int32);
 * console.log(result1);
 * result1['output'].print(); // [[0 0 0 0]]

the method only returns a single tensor, it should be result1.print()


tfjs-core/src/ops/sparse/sparse_segment_sum.ts, line 67 at r1 (raw file):

      convertToTensor(segmentIds, 'segmentIds', 'sparseSegmentSum');

  if ($data.rank < 1) {

does data need to be 2D tensor or at least 1D tensor? The jsDoc says 2D?


tfjs-core/src/ops/sparse/sparse_segment_sum_test.ts, line 38 at r1 (raw file):

}

// function sparseTensorValue2x3x4() {

delete this method if it is not used.


tfjs-core/src/ops/sparse/sparse_segment_sum_test.ts, line 69 at r1 (raw file):

    const result =
        tf.sparse.sparseSegmentSum(TensorValue10(), [8, 3, 0, 9], [0, 1, 2, 2]);
    expectArraysClose(await result.data(), [9, 4, 11]);

should this be [9, 4, 10]?


tfjs-core/src/ops/sparse/sparse_segment_sum_test.ts, line 148 at r1 (raw file):

  });

  it('segments invalid 7', async () => {

can you add tests for invalid input ranks?

Copy link
Contributor Author

@ahmedsabie ahmedsabie left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @ahmedsabie, @lina128, and @pyu10055)


tfjs-backend-cpu/src/kernels/SparseSegmentReduction_impl.ts, line 22 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

parameters isMean is not used ?

good catch I will add it back in for the SparseSegmentMean


tfjs-core/src/ops/sparse/sparse_segment_sum.ts, line 67 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

does data need to be 2D tensor or at least 1D tensor? The jsDoc says 2D?

fixed it should be at least 1D


tfjs-core/src/ops/sparse/sparse_segment_sum_test.ts, line 69 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

should this be [9, 4, 10]?

the array has values 1 through 10 so indices 0 and 9 in segment 2 means value is 1 + 10 i believe


tfjs-core/src/ops/sparse/sparse_segment_sum_test.ts, line 148 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

can you add tests for invalid input ranks?

done

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @ahmedsabie and @lina128)


tfjs-core/src/ops/sparse/sparse_segment_sum_test.ts, line 69 at r1 (raw file):

Previously, ahmedsabie wrote…

the array has values 1 through 10 so indices 0 and 9 in segment 2 means value is 1 + 10 i believe

got it

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you need to disable the tests for WASM

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @ahmedsabie and @lina128)

@ahmedsabie ahmedsabie force-pushed the sparse-segment-sum branch from fe4cf8f to b5c6a8f Compare May 3, 2021 19:15
@ahmedsabie ahmedsabie merged commit 3152315 into tensorflow:master May 3, 2021
@ahmedsabie ahmedsabie deleted the sparse-segment-sum branch May 3, 2021 19:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants