Skip to content

Commit

Permalink
Preserve user supplied JOIN order.
Browse files Browse the repository at this point in the history
JOIN clauses order is important, previous implementation always
put string or arel joins at then end (after auto-generated
association joins).

Fixes rails#12953, rails#15488, rails#16635.
  • Loading branch information
thedarkone committed Sep 12, 2014
1 parent 2dd6ec2 commit 2c80962
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 94 deletions.
145 changes: 111 additions & 34 deletions activerecord/lib/active_record/associations/join_dependency.rb
Expand Up @@ -45,32 +45,107 @@ def column_aliases
Column = Struct.new(:name, :alias)
end

attr_reader :alias_tracker, :base_klass, :join_root
class Tree # :nodoc:
def initialize(associations = nil)
@tree = {}
add_associations(associations) if associations
end

def add_associations(associations)
walk(associations, @tree)
end

def map(&block)
@tree.map(&block)
end

def self.make_tree(associations)
hash = {}
walk_tree associations, hash
hash
private
def walk(associations, hash, strict = true) # recursion is always strict
case associations
when Symbol, String
hash[associations.to_sym] ||= {}
when Array
associations.each do |assoc|
walk assoc, hash
end
when Hash
associations.each do |k,v|
cache = hash[k] ||= {}
walk v, cache
end
else
raise ConfigurationError, associations.inspect if strict
end
end

def self.to_tree(associations = nil)
associations.kind_of?(self) ? associations : new(associations)
end
end

# Same as Tree, except it accepts associations only if these are valid
# AR association joins() params (ie: :books, {:author => :book}, but not
# 'JOINS books' or Arel::Nodes::Join objects).
class JoinsTree < Tree # :nodoc:
def association_join_param?(assocs)
# note that association joins() param can't be a String (strings passed to
# joins() must be literal/valid raw SQL joins), contrast this with Tree
# being able to walk() Strings (this is because Strings are valid includes(),
# references() params)
assocs.kind_of?(Symbol) || assocs.kind_of?(Hash) || assocs.kind_of?(Array)
end

def self.walk_tree(associations, hash)
case associations
when Symbol, String
hash[associations.to_sym] ||= {}
when Array
associations.each do |assoc|
walk_tree assoc, hash
def add_associations(assocs)
if association_join_param?(assocs)
super
true
else
false
end
when Hash
associations.each do |k,v|
cache = hash[k] ||= {}
walk_tree v, cache
end

def drain_associations_as_join_dependency_param(associations_param)
join_dependency_param = nil
drain(associations_param) do |associations_name, subtree, multiple_values_incoming|
if multiple_values_incoming
(join_dependency_param ||= {})[associations_name] = subtree
elsif subtree.empty?
join_dependency_param = associations_name # no need for Hash, can avoid allocation
else
join_dependency_param = {associations_name => subtree}
end
end
join_dependency_param
end

def drain_associations_as_join_infos(join_dependency, associations_param)
join_infos = nil
drain(associations_param) do |association_name, subtree, multiple_values_incoming|
join_infos ||= []
join_infos.concat(join_dependency.make_association_inner_join(association_name))
end
join_infos
end

private
def drain(associations_param)
case associations_param
when Symbol
if subtree = @tree.delete(associations_param)
yield associations_param, subtree, false
end
when Hash, Array
associations_param.public_send(associations_param.kind_of?(Hash) ? :each_key : :each) do |association_name|
if subtree = @tree.delete(association_name)
yield association_name, subtree, true
end
end
end
else
raise ConfigurationError, associations.inspect
end
end

attr_reader :alias_tracker, :base_klass, :join_root

# base is the base class on which operation is taking place.
# associations is the list of associations which are joined using hash, symbol or array.
# joins is the list of all string join commands and arel nodes.
Expand All @@ -92,10 +167,11 @@ def self.walk_tree(associations, hash)
# associations # => [:appointments]
# joins # => []
#
def initialize(base, associations, joins)
def initialize(base, associations, joins = [])
@alias_tracker = AliasTracker.create(base.connection, joins)
@alias_tracker.aliased_name_for(base.table_name, base.table_name) # Updates the count for base.table_name to 1
tree = self.class.make_tree associations
# associations Hash can be used directly, no need to explicitly convert it into Tree
tree = associations.kind_of?(Hash) ? associations : Tree.to_tree(associations)
@join_root = JoinBase.new base, build(tree, base)
@join_root.children.each { |child| construct_tables! @join_root, child }
end
Expand All @@ -104,20 +180,14 @@ def reflections
join_root.drop(1).map!(&:reflection)
end

def join_constraints(outer_joins)
joins = join_root.children.flat_map { |child|
make_inner_joins join_root, child
}

joins.concat outer_joins.flat_map { |oj|
if join_root.match? oj.join_root
walk join_root, oj.join_root
else
oj.join_root.children.flat_map { |child|
make_outer_joins oj.join_root, child
}
end
}
def join_constraints_for_join_dependency(join)
if join_root.match? join.join_root
walk join_root, join.join_root
else
join.join_root.children.flat_map { |child|
make_outer_joins join.join_root, child
}
end
end

def aliases
Expand Down Expand Up @@ -150,6 +220,13 @@ def instantiate(result_set, aliases)
parents.values
end

def make_association_inner_join(association_name)
join_root.children.each do |child|
return make_inner_joins(join_root, child) if child.reflection.name == association_name
end
nil
end

private

def make_constraints(parent, child, tables, join_type)
Expand Down
48 changes: 34 additions & 14 deletions activerecord/lib/active_record/relation/merger.rb
Expand Up @@ -83,26 +83,46 @@ def merge
private

def merge_joins
return if values[:joins].blank?
return if (joins = values[:joins]).blank?

if other.klass == relation.klass
relation.joins!(*values[:joins])
relation.joins!(*joins)
else
joins_dependency, rest = values[:joins].partition do |join|
case join
when Hash, Symbol, Array
true
else
false
# 1) build an association join tree (AR guarantees not to double join
# associations even if they've been accidentally specified twice,
# ie: `Author.joins(:books).joins(:books)`)
assoc_joins_tree = ActiveRecord::Associations::JoinDependency::JoinsTree.new
joins.each {|join| assoc_joins_tree.add_associations(join)}

# try to coalesce/pool JoinDependency allocation, since association joins usually come in batches,
# ie: joins # => [:posts, :comments, :categorizations], while non association joins are usually
# really rare
join_dependency_params = nil
join_values = []

# 2) build join_values iteratively to preserve user supplied JOIN clauses order
joins.each do |join|
# if AR "association" param, ie: :books, or {:author => :book}
if join_dependency_param = assoc_joins_tree.drain_associations_as_join_dependency_param(join)
join_dependency_params ||= []
if join_dependency_param.kind_of?(Array)
join_dependency_params.concat(join_dependency_param)
else
join_dependency_params << join_dependency_param
end
elsif assoc_joins_tree.association_join_param?(join)
# elsif `join` is an already "drained" association join
else # else `join` is not an association join (but a string or an arel join obj)
if join_dependency_params # can't delay instantiating JoinDependency anymore
join_values << ActiveRecord::Associations::JoinDependency.new(other.klass, join_dependency_params)
join_dependency_params = nil
end
join_values << join
end
end
join_values << ActiveRecord::Associations::JoinDependency.new(other.klass, join_dependency_params) if join_dependency_params

join_dependency = ActiveRecord::Associations::JoinDependency.new(other.klass,
joins_dependency,
[])
relation.joins! rest

@relation = relation.joins join_dependency
@relation = relation.joins(*join_values)
end
end

Expand Down
86 changes: 40 additions & 46 deletions activerecord/lib/active_record/relation/query_methods.rb
Expand Up @@ -922,22 +922,6 @@ def where_unscoping(target_value)
bind_values.reject! { |col,_| col.name == target_value }
end

def custom_join_ast(table, joins)
joins = joins.reject(&:blank?)

return [] if joins.empty?

joins.map! do |join|
case join
when Array
join = Arel.sql(join.join(' ')) if array_of_strings?(join)
when String
join = Arel.sql(join)
end
table.create_string_join(join)
end
end

def collapse_wheres(arel, wheres)
predicates = wheres.map do |where|
next where if ::Arel::Nodes::Equality === where
Expand Down Expand Up @@ -1004,44 +988,58 @@ def build_from
end

def build_joins(manager, joins)
buckets = joins.group_by do |join|
case join
when String
:string_join
when Hash, Symbol, Array
:association_join
when ActiveRecord::Associations::JoinDependency
:stashed_join
when Arel::Nodes::Join
:join_node
else
raise 'unknown class: %s' % join.class.name
assoc_joins_tree = ActiveRecord::Associations::JoinDependency::JoinsTree.new
other_joins_hash = {}
joins_to_process = []

# Joins need to be iterated over twice:
# 1) first loop over supplied joins to do some pre-processing, uniquification and
# preparation of params for JoinDependency (it needs to be aware of all joins, so it
# can perform non-conflicting table aliasing)
joins.each do |join|
unless assoc_joins_tree.add_associations(join)
case join
when ActiveRecord::Associations::JoinDependency
when String
next if (join = join.strip).blank? || other_joins_hash[join]
join = other_joins_hash[join] = manager.create_string_join(Arel.sql(join))
when Arel::Nodes::Join
# note: Arel::Nodes::Join subclasses can reliably be used as hash keys
next if other_joins_hash[join]
other_joins_hash[join] = join
else
raise 'unknown class: %s' % join.class.name
end
end
joins_to_process << join
end

association_joins = buckets[:association_join] || []
stashed_association_joins = buckets[:stashed_join] || []
join_nodes = (buckets[:join_node] || []).uniq
string_joins = (buckets[:string_join] || []).map(&:strip).uniq

join_list = join_nodes + custom_join_ast(manager, string_joins)

join_dependency = ActiveRecord::Associations::JoinDependency.new(
@klass,
association_joins,
join_list
assoc_joins_tree,
other_joins_hash.values
)

join_infos = join_dependency.join_constraints stashed_association_joins
# 2) loop over joins again (so that user supplied order of joins is maintained)
# and send them to the manager
joins_to_process.each do |join|
if join_infos = assoc_joins_tree.drain_associations_as_join_infos(join_dependency, join)
append_join_infos(manager, join_infos)
elsif join.kind_of?(ActiveRecord::Associations::JoinDependency)
append_join_infos(manager, join_dependency.join_constraints_for_join_dependency(join))
elsif join.kind_of?(Arel::Nodes::Join)
manager.join_sources << join
end
end

manager
end

def append_join_infos(manager, join_infos)
join_infos.each do |info|
info.joins.each { |join| manager.from(join) }
manager.bind_values.concat info.binds
end

manager.join_sources.concat(join_list)

manager
end

def build_select(arel, selects)
Expand Down Expand Up @@ -1073,10 +1071,6 @@ def reverse_sql_order(order_query)
end
end

def array_of_strings?(o)
o.is_a?(Array) && o.all? { |obj| obj.is_a?(String) }
end

def build_order(arel)
orders = order_values.uniq
orders.reject!(&:blank?)
Expand Down

0 comments on commit 2c80962

Please sign in to comment.