Skip to content

Commit

Permalink
A simpler and significantly more effective algorithm for memory term …
Browse files Browse the repository at this point in the history
…reduction.

PiperOrigin-RevId: 629241334
  • Loading branch information
tensorflower-gardener committed Apr 30, 2024
1 parent 3fb103c commit 6c238b6
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 14 deletions.
Expand Up @@ -38,8 +38,8 @@ using LiveIdx = int64_t; // Indexes into the liveness range (like a time point)
using GroupIdx = int64_t; // Indexes into the list of groups

// A pair of primitive indices.
using PrimPair = std::pair<PrimIdx, PrimIdx>;
// A live index paired with a primitive index.
using LiveAndPrim = std::pair<LiveIdx, PrimIdx>;
// A (first, second) pair of live indices; presumably the lower and upper
// bounds of a liveness range -- confirm against IsValid().
using Interval = std::pair<LiveIdx, LiveIdx>;
// An interval paired with a primitive index. Because std::pair compares
// lexicographically, ordered containers of ActivePrim sort by interval first
// and by primitive index second.
using ActivePrim = std::pair<Interval, PrimIdx>;

bool IsValid(const Interval& interval) {
return interval.first <= interval.second;
Expand Down Expand Up @@ -209,24 +209,21 @@ void MemoryTermReducer::Reduce(int64_t num_lives, int64_t num_primitives,
// A function to sweep through live points & merge large overlaps.
auto SweepAndMerge = [&num_lives, &enter, &evict, &CalcOverlap, &CalcNumTerms,
&MergeIntoGroup, &UpdatePrimitive, this]() -> bool {
absl::btree_set<LiveAndPrim> actives; // Active prims sorted by first value
absl::btree_set<ActivePrim> actives; // Active prims sorted by interval.
absl::btree_multimap<int64_t, PrimPair> overlaps;
for (LiveIdx live_idx = 0; live_idx < num_lives; ++live_idx) {
for (const PrimIdx prim_idx : enter[live_idx]) {
actives.insert({live_idx, prim_idx});
actives.insert({reduced_intervals_[prim_idx], prim_idx});
}
for (const PrimIdx prim_idx : evict[live_idx]) {
std::optional<LiveAndPrim> active; // The active prim we overlap with.
auto prim =
actives.find({reduced_intervals_[prim_idx].first, prim_idx});
if (auto next = prim; ++next != actives.end()) {
active = *next; // Choose a prim that started soon after we did.
} else if (actives.begin()->second != prim_idx) {
active = *actives.begin(); // Otherwise, choose the earliest prim.
}
actives.erase(prim);
if (!active) continue;
overlaps.insert({active->first - live_idx, {prim_idx, active->second}});
auto active = actives.find({reduced_intervals_[prim_idx], prim_idx});
if (++active == actives.end()) continue; // No prims left to merge with
std::optional<Interval> overlap = CalcOverlap(prim_idx, active->second);
if (!overlap) continue;
overlaps.insert({-length(*overlap), {prim_idx, active->second}});
}
for (const PrimIdx prim_idx : evict[live_idx]) {
actives.erase({reduced_intervals_[prim_idx], prim_idx});
}
}
bool changed = false;
Expand Down
Expand Up @@ -606,6 +606,122 @@ TEST(AutoShardingMemoryTest, OneIterationOnly) {
EXPECT_EQ(reducer.GetReducedGroups(), expected_reduced_groups);
}

// Four primitives share a common start (t=0) and end progressively earlier,
// forming a staircase whose longest step is at the bottom.
//
//  |                  ==>  |  55555
//  |                  ==>  |  44444444444
//  |  33333           ==>  |  .....            Groups:
//  |  22222222        ==>  |  .....222           m[4] = m[0] + m[1]
//  |  11111111111     ==>  |  ...........        m[5] = m[2] + m[3]
//  |  00000000000000  ==>  |  ...........000
//  +-------------->   ==>  +-------------->
//       (time)                  (time)
TEST(AutoShardingMemoryTest, StairsBottomLeft) {
  using IntervalList = std::vector<std::pair<int64_t, int64_t>>;

  // Input liveness of primitives 0..3.
  const IntervalList input = {{0, 13}, {0, 10}, {0, 7}, {0, 4}};

  // Expected result: entries 0..3 are the primitives' leftover intervals (an
  // entry with second < first, e.g. {11, -1}, has nothing left), and entries
  // 4..5 are the intervals of the two groups.
  const IntervalList want_intervals = {
      {11, 13}, {11, -1}, {5, 7}, {5, -1}, {0, 10}, {0, 4}};
  const std::vector<absl::btree_set<int64_t>> want_groups = {{0, 1}, {2, 3}};
  const std::vector<std::vector<int64_t>> want_live = {};
  // Presumably (term count before reduction, term count after) -- confirm
  // against MemoryTermReducer::Reduce.
  const std::pair<int64_t, int64_t> want_num_terms = {38, 26};

  MemoryTermReducer reducer;
  const auto got_num_terms =
      reducer.Reduce(/*num_lives=*/14, /*num_primitives=*/4, Convert(input),
                     /*max_iterations=*/1);

  EXPECT_EQ(got_num_terms, want_num_terms);
  EXPECT_EQ(reducer.GetReducedLive(), want_live);
  EXPECT_EQ(reducer.GetReducedIntervals(), want_intervals);
  EXPECT_EQ(reducer.GetReducedGroups(), want_groups);
}

// Four primitives share a common start (t=0) and end progressively later,
// forming a staircase whose longest step is at the top.
//
//  |                  ==>  |  55555
//  |                  ==>  |  44444444444
//  |  33333333333333  ==>  |  ...........333   Groups:
//  |  22222222222     ==>  |  ...........        m[4] = m[2] + m[3]
//  |  11111111        ==>  |  .....111           m[5] = m[0] + m[1]
//  |  00000           ==>  |  .....
//  +-------------->   ==>  +-------------->
//       (time)                  (time)
TEST(AutoShardingMemoryTest, StairsTopLeft) {
  using IntervalList = std::vector<std::pair<int64_t, int64_t>>;

  // Input liveness of primitives 0..3.
  const IntervalList input = {{0, 4}, {0, 7}, {0, 10}, {0, 13}};

  // Expected result: entries 0..3 are the primitives' leftover intervals (an
  // entry with second < first, e.g. {5, -1}, has nothing left), and entries
  // 4..5 are the intervals of the two groups.
  const IntervalList want_intervals = {
      {5, -1}, {5, 7}, {11, -1}, {11, 13}, {0, 10}, {0, 4}};
  const std::vector<absl::btree_set<int64_t>> want_groups = {{2, 3}, {0, 1}};
  const std::vector<std::vector<int64_t>> want_live = {};
  // Presumably (term count before reduction, term count after) -- confirm
  // against MemoryTermReducer::Reduce.
  const std::pair<int64_t, int64_t> want_num_terms = {38, 26};

  MemoryTermReducer reducer;
  const auto got_num_terms =
      reducer.Reduce(/*num_lives=*/14, /*num_primitives=*/4, Convert(input),
                     /*max_iterations=*/1);

  EXPECT_EQ(got_num_terms, want_num_terms);
  EXPECT_EQ(reducer.GetReducedLive(), want_live);
  EXPECT_EQ(reducer.GetReducedIntervals(), want_intervals);
  EXPECT_EQ(reducer.GetReducedGroups(), want_groups);
}

// Four primitives share a common end (t=13) and start progressively earlier,
// forming a staircase whose longest step is at the top.
//
//  |                  ==>  |           55555
//  |                  ==>  |     44444444444
//  |  33333333333333  ==>  |  333...........   Groups:
//  |     22222222222  ==>  |     ...........     m[4] = m[2] + m[3]
//  |        11111111  ==>  |        111.....     m[5] = m[0] + m[1]
//  |           00000  ==>  |           .....
//  +-------------->   ==>  +-------------->
//       (time)                  (time)
TEST(AutoShardingMemoryTest, StairsTopRight) {
  using IntervalList = std::vector<std::pair<int64_t, int64_t>>;

  // Input liveness of primitives 0..3.
  const IntervalList input = {{9, 13}, {6, 13}, {3, 13}, {0, 13}};

  // Expected result: entries 0..3 are the primitives' leftover intervals (an
  // entry with second < first, e.g. {14, 8}, has nothing left), and entries
  // 4..5 are the intervals of the two groups.
  const IntervalList want_intervals = {
      {14, 8}, {6, 8}, {14, 2}, {0, 2}, {3, 13}, {9, 13}};
  const std::vector<absl::btree_set<int64_t>> want_groups = {{2, 3}, {0, 1}};
  const std::vector<std::vector<int64_t>> want_live = {};
  // Presumably (term count before reduction, term count after) -- confirm
  // against MemoryTermReducer::Reduce.
  const std::pair<int64_t, int64_t> want_num_terms = {38, 26};

  MemoryTermReducer reducer;
  const auto got_num_terms =
      reducer.Reduce(/*num_lives=*/14, /*num_primitives=*/4, Convert(input),
                     /*max_iterations=*/1);

  EXPECT_EQ(got_num_terms, want_num_terms);
  EXPECT_EQ(reducer.GetReducedLive(), want_live);
  EXPECT_EQ(reducer.GetReducedIntervals(), want_intervals);
  EXPECT_EQ(reducer.GetReducedGroups(), want_groups);
}

// Four primitives share a common end (t=13) and start progressively later,
// forming a staircase whose longest step is at the bottom.
//
//  |                  ==>  |           55555
//  |                  ==>  |     44444444444
//  |           33333  ==>  |           .....   Groups:
//  |        22222222  ==>  |        222.....     m[4] = m[0] + m[1]
//  |     11111111111  ==>  |     ...........     m[5] = m[2] + m[3]
//  |  00000000000000  ==>  |  000...........
//  +-------------->   ==>  +-------------->
//       (time)                  (time)
TEST(AutoShardingMemoryTest, StairsBottomRight) {
  using IntervalList = std::vector<std::pair<int64_t, int64_t>>;

  // Input liveness of primitives 0..3.
  const IntervalList input = {{0, 13}, {3, 13}, {6, 13}, {9, 13}};

  // Expected result: entries 0..3 are the primitives' leftover intervals (an
  // entry with second < first, e.g. {14, 2}, has nothing left), and entries
  // 4..5 are the intervals of the two groups.
  const IntervalList want_intervals = {
      {0, 2}, {14, 2}, {6, 8}, {14, 8}, {3, 13}, {9, 13}};
  const std::vector<absl::btree_set<int64_t>> want_groups = {{0, 1}, {2, 3}};
  const std::vector<std::vector<int64_t>> want_live = {};
  // Presumably (term count before reduction, term count after) -- confirm
  // against MemoryTermReducer::Reduce.
  const std::pair<int64_t, int64_t> want_num_terms = {38, 26};

  MemoryTermReducer reducer;
  const auto got_num_terms =
      reducer.Reduce(/*num_lives=*/14, /*num_primitives=*/4, Convert(input),
                     /*max_iterations=*/1);

  EXPECT_EQ(got_num_terms, want_num_terms);
  EXPECT_EQ(reducer.GetReducedLive(), want_live);
  EXPECT_EQ(reducer.GetReducedIntervals(), want_intervals);
  EXPECT_EQ(reducer.GetReducedGroups(), want_groups);
}

// clang-format on

} // namespace
Expand Down

0 comments on commit 6c238b6

Please sign in to comment.