Skip to content

Commit

Permalink
Add support for recursive CTEs through with.recursive
Browse files Browse the repository at this point in the history
  • Loading branch information
ClearlyClaire committed Apr 19, 2024
1 parent 1842fba commit 073cb48
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 5 deletions.
49 changes: 46 additions & 3 deletions activerecord/lib/active_record/relation/query_methods.rb
Expand Up @@ -144,6 +144,23 @@ def scope_association_reflection(association)
end
end

# +WithChain+ objects act as placeholder for queries in which +with+ does not have any parameter.
# In this case, +with+ can be chained to return a new relation.
class WithChain
def initialize(scope) # :nodoc:
@scope = scope
end

# Returns a new relation in which Common Table Expressions (CTEs) are flagged as recursive.
#
# See QueryMethods#with for more details.
def recursive(*args)
@scope.with_values += args
@scope.with_is_recursive = true
@scope
end
end

# A wrapper to distinguish CTE joins from other nodes.
class CTEJoin # :nodoc:
attr_reader :name
Expand Down Expand Up @@ -179,6 +196,18 @@ def #{method_name}=(value) # def includes_values=(value)
CODE
end

# TODO: This is akin to how `Relation::VALUE_METHODS` are defined,
# but this does not neatly fit into one of the existing categories.
# Maybe we should make a full-fledged `WithClause`
def with_is_recursive
@values.fetch(:with_is_recursive, false)
end

def with_is_recursive=(value)
assert_mutability!
@values[:with_is_recursive] = value
end

alias extensions extending_values

# Specify associations +args+ to be eager loaded to prevent N + 1 queries.
Expand Down Expand Up @@ -446,6 +475,17 @@ def _select!(*fields) # :nodoc:
# # )
# # SELECT * FROM posts
#
# This can be used to write recursive CTEs:
#
# Post.with.recursive(post_and_replies: [Post.where(id: 42), Post.joins('JOIN post_and_replies ON posts.in_reply_to_id = post_and_replies.id')])
# # => ActiveRecord::Relation
# # WITH post_and_replies AS (
# # (SELECT * FROM posts WHERE id = 42)
# # UNION ALL
# # (SELECT * FROM posts JOIN posts_and_replies ON posts.in_reply_to_id = posts_and_replies.id)
# # )
# # SELECT * FROM posts
#
# Once you define Common Table Expression you can use custom +FROM+ value or +JOIN+ to reference it.
#
# Post.with(posts_with_tags: Post.where("tags_count > ?", 0)).from("posts_with_tags AS posts")
Expand Down Expand Up @@ -484,8 +524,11 @@ def _select!(*fields) # :nodoc:
# .with(posts_with_comments: Post.where("comments_count > ?", 0))
# .with(posts_with_tags: Post.where("tags_count > ?", 0))
def with(*args)
check_if_method_has_arguments!(__callee__, args)
spawn.with!(*args)
if args.empty?
WithChain.new(spawn)
else
spawn.with!(*args)
end
end

# Like #with, but modifies relation in place.
Expand Down Expand Up @@ -1829,7 +1872,7 @@ def build_with(arel)
build_with_value_from_hash(with_value)
end

arel.with(with_statements)
with_is_recursive ? arel.with(:recursive, with_statements) : arel.with(with_statements)
end

def build_with_value_from_hash(hash)
Expand Down
2 changes: 1 addition & 1 deletion activerecord/test/cases/relation/delegation_test.rb
Expand Up @@ -62,7 +62,7 @@ class QueryingMethodsDelegationTest < ActiveRecord::TestCase
ActiveRecord::SpawnMethods.public_instance_methods(false) - [:spawn, :merge!] +
ActiveRecord::QueryMethods.public_instance_methods(false).reject { |method|
method.end_with?("=", "!", "?", "value", "values", "clause")
} - [:reverse_order, :arel, :extensions, :construct_join_dependency] + [
} - [:reverse_order, :arel, :extensions, :construct_join_dependency, :with_is_recursive] + [
:any?, :many?, :none?, :one?,
:first_or_create, :first_or_create!, :first_or_initialize,
:find_or_create_by, :find_or_create_by!, :find_or_initialize_by,
Expand Down
10 changes: 9 additions & 1 deletion activerecord/test/cases/relation/with_test.rb
Expand Up @@ -57,7 +57,6 @@ def test_with_when_called_from_active_record_scope
end

def test_with_when_invalid_params_are_passed
assert_raise(ArgumentError) { Post.with }
assert_raise(ArgumentError) { Post.with(posts_with_tags: nil).load }
assert_raise(ArgumentError) { Post.with(posts_with_tags: [Post.where("tags_count > 0"), 5]).load }
end
Expand All @@ -73,6 +72,15 @@ def test_with_when_passing_arrays
assert_equal (POSTS_WITH_TAGS + POSTS_WITH_COMMENTS).sort, relation.order(:id).pluck(:id)
end

def test_with_recursive
# TODO: actually test recursive behavior
relation = Post
.with.recursive(posts_with_comments: Post.where("legacy_comments_count > 0"))
.from("posts_with_comments AS posts")

assert_equal POSTS_WITH_COMMENTS, relation.order(:id).pluck(:id)
end

def test_with_joins
relation = Post
.with(commented_posts: Comment.select(:post_id).distinct)
Expand Down

0 comments on commit 073cb48

Please sign in to comment.