Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for recursive CTEs in ActiveRecord #51601

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 24 additions & 0 deletions activerecord/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
* Added support for recursive commont table expressions.

```ruby
Post.with_recursive(
post_and_replies: [
Post.where(id: 42),
Post.joins('JOIN post_and_replies ON posts.in_reply_to_id = post_and_replies.id'),
]
)
```

Generates the following SQL:

```sql
WITH RECURSIVE "post_and_replies" AS (
(SELECT "posts".* FROM "posts" WHERE "posts"."id" = 42)
UNION ALL
(SELECT "posts".* FROM "posts" JOIN post_and_replies ON posts.in_reply_to_id = post_and_replies.id)
)
SELECT "posts".* FROM "posts"
```

*ClearlyClaire*

* `validate_constraint` can be called in a `change_table` block.

ex:
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/querying.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module Querying
:and, :or, :annotate, :optimizer_hints, :extending,
:having, :create_with, :distinct, :references, :none, :unscope, :merge, :except, :only,
:count, :average, :minimum, :maximum, :sum, :calculate,
:pluck, :pick, :ids, :async_ids, :strict_loading, :excluding, :without, :with,
:pluck, :pick, :ids, :async_ids, :strict_loading, :excluding, :without, :with, :with_recursive,
:async_count, :async_average, :async_minimum, :async_maximum, :async_sum, :async_pluck, :async_pick,
].freeze # :nodoc:
delegate(*QUERYING_METHODS, to: :all)
Expand Down
2 changes: 1 addition & 1 deletion activerecord/lib/active_record/relation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def exec_explain(&block)
:reverse_order, :distinct, :create_with, :skip_query_cache]

CLAUSE_METHODS = [:where, :having, :from]
INVALID_METHODS_FOR_DELETE_ALL = [:distinct, :with]
INVALID_METHODS_FOR_DELETE_ALL = [:distinct, :with, :with_recursive]

VALUE_METHODS = MULTI_VALUE_METHODS + SINGLE_VALUE_METHODS + CLAUSE_METHODS

Expand Down
65 changes: 54 additions & 11 deletions activerecord/lib/active_record/relation/query_methods.rb
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,17 @@ def _select!(*fields) # :nodoc:
# # )
# # SELECT * FROM posts
#
# You can also pass an array of sub-queries to be joined in a +UNION ALL+.
#
# Post.with(posts_with_tags_or_comments: [Post.where("tags_count > ?", 0), Post.where("comments_count > ?", 0)])
# # => ActiveRecord::Relation
# # WITH posts_with_tags_or_comments AS (
# # (SELECT * FROM posts WHERE (tags_count > 0))
# # UNION ALL
# # (SELECT * FROM posts WHERE (comments_count > 0))
# # )
# # SELECT * FROM posts
#
# Once you define Common Table Expression you can use custom +FROM+ value or +JOIN+ to reference it.
#
# Post.with(posts_with_tags: Post.where("tags_count > ?", 0)).from("posts_with_tags AS posts")
Expand Down Expand Up @@ -475,7 +486,12 @@ def _select!(*fields) # :nodoc:
def with(*args)
raise ArgumentError, "ActiveRecord::Relation#with does not accept a block" if block_given?
check_if_method_has_arguments!(__callee__, args)
spawn.with!(*args)

if args.empty?
WithChain.new(spawn)
else
spawn.with!(*args)
end
end

# Like #with, but modifies relation in place.
Expand All @@ -484,6 +500,30 @@ def with!(*args) # :nodoc:
self
end

# Add a recursive Common Table Expression (CTE) that you can then reference within another SELECT statement.
#
# Post.with_recursive(post_and_replies: [Post.where(id: 42), Post.joins('JOIN post_and_replies ON posts.in_reply_to_id = post_and_replies.id')])
# # => ActiveRecord::Relation
# # WITH post_and_replies AS (
# # (SELECT * FROM posts WHERE id = 42)
# # UNION ALL
# # (SELECT * FROM posts JOIN posts_and_replies ON posts.in_reply_to_id = posts_and_replies.id)
# # )
# # SELECT * FROM posts
#
# See `#with` for more information.
def with_recursive(*args)
check_if_method_has_arguments!(__callee__, args)
spawn.with_recursive!(*args)
end

# Like #with_recursive but modifies the relation in place.
def with_recursive!(*args) # :nodoc:
self.with_values += args
@with_is_recursive = true
self
end

# Allows you to change a previously set select statement.
#
# Post.select(:title, :body)
Expand Down Expand Up @@ -1846,20 +1886,23 @@ def build_with(arel)
build_with_value_from_hash(with_value)
end

arel.with(with_statements)
@with_is_recursive ? arel.with(:recursive, with_statements) : arel.with(with_statements)
end

def build_with_value_from_hash(hash)
hash.map do |name, value|
expression =
case value
when Arel::Nodes::SqlLiteral then Arel::Nodes::Grouping.new(value)
when ActiveRecord::Relation then value.arel
when Arel::SelectManager then value
else
raise ArgumentError, "Unsupported argument type: `#{value}` #{value.class}"
end
Arel::Nodes::TableAlias.new(expression, name)
Arel::Nodes::TableAlias.new(build_with_expression_from_value(value), name)
end
end

def build_with_expression_from_value(value)
case value
when Arel::Nodes::SqlLiteral then Arel::Nodes::Grouping.new(value)
when ActiveRecord::Relation then value.arel
when Arel::SelectManager then value
when Array then value.map { |q| build_with_expression_from_value(q) }.reduce { |result, value| result.union(:all, value) }
else
raise ArgumentError, "Unsupported argument type: `#{value}` #{value.class}"
end
end

Expand Down
34 changes: 30 additions & 4 deletions activerecord/test/cases/relation/with_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
require "cases/helper"
require "models/comment"
require "models/post"
require "models/company"

module ActiveRecord
class WithTest < ActiveRecord::TestCase
fixtures :comments
fixtures :posts
fixtures :comments, :posts, :companies

POSTS_WITH_TAGS = [1, 2, 7, 8, 9, 10, 11].freeze
POSTS_WITH_COMMENTS = [1, 2, 4, 5, 7].freeze
Expand Down Expand Up @@ -57,9 +57,35 @@ def test_with_when_called_from_active_record_scope
end

def test_with_when_invalid_params_are_passed
assert_raise(ArgumentError) { Post.with }
assert_raise(ArgumentError) { Post.with(posts_with_tags: nil).load }
assert_raise(ArgumentError) { Post.with(posts_with_tags: [Post.where("tags_count > 0")]).load }
assert_raise(ArgumentError) { Post.with(posts_with_tags: [Post.where("tags_count > 0"), 5]).load }
end

def test_with_when_passing_arrays
relation = Post
.with(posts_with_tags_or_comments: [
Post.where("tags_count > 0"),
Post.where("legacy_comments_count > 0")
])
.from("posts_with_tags_or_comments AS posts")

assert_equal (POSTS_WITH_TAGS + POSTS_WITH_COMMENTS).sort, relation.order(:id).pluck(:id)
end

def test_with_recursive
top_companies = Company.where(firm_id: nil).to_a
child_companies = Company.where(firm_id: top_companies).to_a
top_companies_and_children = (top_companies.map(&:id) + child_companies.map(&:id)).sort

relation = Company.with_recursive(
top_companies_and_children: [
Company.where(firm_id: nil),
Company.joins("JOIN top_companies_and_children ON companies.firm_id = top_companies_and_children.id"),
]
).from("top_companies_and_children AS companies")

assert_equal top_companies_and_children, relation.order(:id).pluck(:id)
assert_match "WITH RECURSIVE", relation.to_sql
end

def test_with_joins
Expand Down