diff --git a/activerecord/lib/active_record/associations/join_dependency.rb b/activerecord/lib/active_record/associations/join_dependency.rb index ec5c189cd31c5..231c8a5c1b7cf 100644 --- a/activerecord/lib/active_record/associations/join_dependency.rb +++ b/activerecord/lib/active_record/associations/join_dependency.rb @@ -45,32 +45,107 @@ def column_aliases Column = Struct.new(:name, :alias) end - attr_reader :alias_tracker, :base_klass, :join_root + class Tree # :nodoc: + def initialize(associations = nil) + @tree = {} + add_associations(associations) if associations + end + + def add_associations(associations) + walk(associations, @tree) + end + + def map(&block) + @tree.map(&block) + end - def self.make_tree(associations) - hash = {} - walk_tree associations, hash - hash + private + def walk(associations, hash, strict = true) # recursion is always strict + case associations + when Symbol, String + hash[associations.to_sym] ||= {} + when Array + associations.each do |assoc| + walk assoc, hash + end + when Hash + associations.each do |k,v| + cache = hash[k] ||= {} + walk v, cache + end + else + raise ConfigurationError, associations.inspect if strict + end + end + + def self.to_tree(associations = nil) + associations.kind_of?(self) ? associations : new(associations) + end end + + # Same as Tree, except it accepts associations only if these are valid + # AR association joins() params (ie: :books, {:author => :book}, but not + # 'JOINS books' or Arel::Nodes::Join objects). + class JoinsTree < Tree # :nodoc: + def association_join_param?(assocs) + # note that association joins() param can't be a String (strings passed to + # joins() must be literal/valid raw SQL joins), contrast this with Tree + # being able to walk() Strings (this is because Strings are valid includes(), + # references() params) + assocs.kind_of?(Symbol) || assocs.kind_of?(Hash) || assocs.kind_of?(Array) + end - def self.walk_tree(associations, hash) - case associations - when Symbol, String - hash[associations.to_sym] ||= {} - when Array - associations.each do |assoc| - walk_tree assoc, hash + def add_associations(assocs) + if association_join_param?(assocs) + super + true + else + false end - when Hash - associations.each do |k,v| - cache = hash[k] ||= {} - walk_tree v, cache + end + + def drain_associations_as_join_dependency_param(associations_param) + join_dependency_param = nil + drain(associations_param) do |associations_name, subtree, multiple_values_incoming| + if multiple_values_incoming + (join_dependency_param ||= {})[associations_name] = subtree + elsif subtree.empty? + join_dependency_param = associations_name # no need for Hash, can avoid allocation + else + join_dependency_param = {associations_name => subtree} + end + end + join_dependency_param + end + + def drain_associations_as_join_infos(join_dependency, associations_param) + join_infos = nil + drain(associations_param) do |association_name, subtree, multiple_values_incoming| + join_infos ||= [] + join_infos.concat(join_dependency.make_association_inner_join(association_name)) + end + join_infos + end + + private + def drain(associations_param) + case associations_param + when Symbol + if subtree = @tree.delete(associations_param) + yield associations_param, subtree, false + end + when Hash, Array + associations_param.public_send(associations_param.kind_of?(Hash) ? :each_key : :each) do |association_name| + if subtree = @tree.delete(association_name) + yield association_name, subtree, true + end + end end - else - raise ConfigurationError, associations.inspect end end + attr_reader :alias_tracker, :base_klass, :join_root + # base is the base class on which operation is taking place. # associations is the list of associations which are joined using hash, symbol or array. # joins is the list of all string join commands and arel nodes. @@ -92,10 +167,11 @@ def self.walk_tree(associations, hash) # associations # => [:appointments] # joins # => [] # - def initialize(base, associations, joins) + def initialize(base, associations, joins = []) @alias_tracker = AliasTracker.create(base.connection, joins) @alias_tracker.aliased_name_for(base.table_name, base.table_name) # Updates the count for base.table_name to 1 - tree = self.class.make_tree associations + # associations Hash can be used directly, no need to explicitly convert it into Tree + tree = associations.kind_of?(Hash) ? associations : Tree.to_tree(associations) @join_root = JoinBase.new base, build(tree, base) @join_root.children.each { |child| construct_tables! @join_root, child } end @@ -104,20 +180,14 @@ def reflections join_root.drop(1).map!(&:reflection) end - def join_constraints(outer_joins) - joins = join_root.children.flat_map { |child| - make_inner_joins join_root, child - } - - joins.concat outer_joins.flat_map { |oj| - if join_root.match? oj.join_root - walk join_root, oj.join_root - else - oj.join_root.children.flat_map { |child| - make_outer_joins oj.join_root, child - } - end - } + def join_constraints_for_join_dependency(join) + if join_root.match? join.join_root + walk join_root, join.join_root + else + join.join_root.children.flat_map { |child| + make_outer_joins join.join_root, child + } + end end def aliases @@ -150,6 +220,13 @@ def instantiate(result_set, aliases) parents.values end + def make_association_inner_join(association_name) + join_root.children.each do |child| + return make_inner_joins(join_root, child) if child.reflection.name == association_name + end + nil + end + private def make_constraints(parent, child, tables, join_type) diff --git a/activerecord/lib/active_record/relation/merger.rb b/activerecord/lib/active_record/relation/merger.rb index ac41d0aa80522..9134f81d34b34 100644 --- a/activerecord/lib/active_record/relation/merger.rb +++ b/activerecord/lib/active_record/relation/merger.rb @@ -83,26 +83,46 @@ def merge private def merge_joins - return if values[:joins].blank? + return if (joins = values[:joins]).blank? if other.klass == relation.klass - relation.joins!(*values[:joins]) + relation.joins!(*joins) else - joins_dependency, rest = values[:joins].partition do |join| - case join - when Hash, Symbol, Array - true - else - false + # 1) build an association join tree (AR guarantees not to double join + # associations even if they've been accidentally specified twice, + # ie: `Author.joins(:books).joins(:books)`) + assoc_joins_tree = ActiveRecord::Associations::JoinDependency::JoinsTree.new + joins.each {|join| assoc_joins_tree.add_associations(join)} + + # try to coalesce/pool JoinDependency allocation, since association joins usually come in batches, + # ie: joins # => [:posts, :comments, :categorizations], while non association joins are usually + # really rare + join_dependency_params = nil + join_values = [] + + # 2) build join_values iteratively to preserve user supplied JOIN clauses order + joins.each do |join| + # if AR "association" param, ie: :books, or {:author => :book} + if join_dependency_param = assoc_joins_tree.drain_associations_as_join_dependency_param(join) + join_dependency_params ||= [] + if join_dependency_param.kind_of?(Array) + join_dependency_params.concat(join_dependency_param) + else + join_dependency_params << join_dependency_param + end + elsif assoc_joins_tree.association_join_param?(join) + # elsif `join` is an already "drained" association join + else # else `join` is not an association join (but a string or an arel join obj) + if join_dependency_params # can't delay instantiating JoinDependency anymore + join_values << ActiveRecord::Associations::JoinDependency.new(other.klass, join_dependency_params) + join_dependency_params = nil + end + join_values << join end end + join_values << ActiveRecord::Associations::JoinDependency.new(other.klass, join_dependency_params) if join_dependency_params - join_dependency = ActiveRecord::Associations::JoinDependency.new(other.klass, - joins_dependency, - []) - relation.joins! rest - - @relation = relation.joins join_dependency + @relation = relation.joins(*join_values) end end diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index bbddd28cccbc7..abe63717694c8 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -922,22 +922,6 @@ def where_unscoping(target_value) bind_values.reject! { |col,_| col.name == target_value } end - def custom_join_ast(table, joins) - joins = joins.reject(&:blank?) - - return [] if joins.empty? - - joins.map! do |join| - case join - when Array - join = Arel.sql(join.join(' ')) if array_of_strings?(join) - when String - join = Arel.sql(join) - end - table.create_string_join(join) - end - end - def collapse_wheres(arel, wheres) predicates = wheres.map do |where| next where if ::Arel::Nodes::Equality === where @@ -1004,44 +988,58 @@ def build_from end def build_joins(manager, joins) - buckets = joins.group_by do |join| - case join - when String - :string_join - when Hash, Symbol, Array - :association_join - when ActiveRecord::Associations::JoinDependency - :stashed_join - when Arel::Nodes::Join - :join_node - else - raise 'unknown class: %s' % join.class.name + assoc_joins_tree = ActiveRecord::Associations::JoinDependency::JoinsTree.new + other_joins_hash = {} + joins_to_process = [] + + # Joins need to be iterated over twice: + # 1) first loop over supplied joins to do some pre-processing, uniquification and + # preparation of params for JoinDependency (it needs to be aware of all joins, so it + # can perform non-conflicting table aliasing) + joins.each do |join| + unless assoc_joins_tree.add_associations(join) + case join + when ActiveRecord::Associations::JoinDependency + when String + next if (join = join.strip).blank? || other_joins_hash[join] + join = other_joins_hash[join] = manager.create_string_join(Arel.sql(join)) + when Arel::Nodes::Join + # note: Arel::Nodes::Join subclasses can reliably be used as hash keys + next if other_joins_hash[join] + other_joins_hash[join] = join + else + raise 'unknown class: %s' % join.class.name + end end + joins_to_process << join end - association_joins = buckets[:association_join] || [] - stashed_association_joins = buckets[:stashed_join] || [] - join_nodes = (buckets[:join_node] || []).uniq - string_joins = (buckets[:string_join] || []).map(&:strip).uniq - - join_list = join_nodes + custom_join_ast(manager, string_joins) - join_dependency = ActiveRecord::Associations::JoinDependency.new( @klass, - association_joins, - join_list + assoc_joins_tree, + other_joins_hash.values ) - join_infos = join_dependency.join_constraints stashed_association_joins + # 2) loop over joins again (so that user supplied order of joins is maintained) + # and send them to the manager + joins_to_process.each do |join| + if join_infos = assoc_joins_tree.drain_associations_as_join_infos(join_dependency, join) + append_join_infos(manager, join_infos) + elsif join.kind_of?(ActiveRecord::Associations::JoinDependency) + append_join_infos(manager, join_dependency.join_constraints_for_join_dependency(join)) + elsif join.kind_of?(Arel::Nodes::Join) + manager.join_sources << join + end + end + + manager + end + def append_join_infos(manager, join_infos) join_infos.each do |info| info.joins.each { |join| manager.from(join) } manager.bind_values.concat info.binds end - - manager.join_sources.concat(join_list) - - manager end def build_select(arel, selects) @@ -1073,10 +1071,6 @@ def reverse_sql_order(order_query) end end - def array_of_strings?(o) - o.is_a?(Array) && o.all? { |obj| obj.is_a?(String) } - end - def build_order(arel) orders = order_values.uniq orders.reject!(&:blank?) diff --git a/activerecord/test/cases/associations/inner_join_association_test.rb b/activerecord/test/cases/associations/inner_join_association_test.rb index 07cf65a760b7b..807e3548a5b2f 100644 --- a/activerecord/test/cases/associations/inner_join_association_test.rb +++ b/activerecord/test/cases/associations/inner_join_association_test.rb @@ -2,6 +2,7 @@ require 'models/post' require 'models/comment' require 'models/author' +require 'models/book' require 'models/essay' require 'models/category' require 'models/categorization' @@ -136,4 +137,35 @@ def test_find_with_conditions_on_through_reflection assert_equal 0, categories.first.special_categorizations.size assert_equal 1, categories.second.special_categorizations.size end + + test "join clauses are emitted in user specified order" do + # test permutes all possible ways to do a join in AR and tests that user provided JOIN clause order is preserved in the emitted SQL + [:posts, :categorizations, :essays].map do |association_sym| + author_table = Author.arel_table + reflection = Author.reflect_on_association(association_sym) + other_table = reflection.klass.arel_table + arel_join = author_table.create_join(other_table, author_table.create_on(other_table[reflection.foreign_key].eq(author_table[reflection.association_primary_key]))) + [association_sym, arel_join, arel_join.to_sql] + end.permutation do |join_a, join_b, join_c| + join_a.each do |join_a_version| + join_b.each do |join_b_version| + join_c.each do |join_c_version| + join_a_token = join_a_version.respond_to?(:to_sql) ? join_a_version.to_sql : join_a_version.to_s + join_b_token = join_b_version.respond_to?(:to_sql) ? join_b_version.to_sql : join_b_version.to_s + join_c_token = join_c_version.respond_to?(:to_sql) ? join_c_version.to_sql : join_c_version.to_s + + # this tests sql generation in AR::Relation::QueryMethods#build_joins + sql = Author.joins(join_a_version, join_b_version, join_c_version).to_sql + assert(sql.index(join_a_token) < sql.index(join_b_token), "#{join_a_token.inspect} must precede #{join_b_token.inspect} in #{sql.inspect}") + assert(sql.index(join_b_token) < sql.index(join_c_token), "#{join_b_token.inspect} must precede #{join_c_token.inspect} in #{sql.inspect}") + + # this tests relation join merging in AR::Relation::Merger#merge_joins + sql = Book.all.merge(Author.joins(join_a_version, join_b_version, join_c_version)).to_sql + assert(sql.index(join_a_token) < sql.index(join_b_token), "#{join_a_token.inspect} must precede #{join_b_token.inspect} in #{sql.inspect}") + assert(sql.index(join_b_token) < sql.index(join_c_token), "#{join_b_token.inspect} must precede #{join_c_token.inspect} in #{sql.inspect}") + end + end + end + end + end end