From 9c563a2bcde360b868cfb8854239c34072d0c798 Mon Sep 17 00:00:00 2001 From: Alex Tan Date: Thu, 28 Mar 2024 19:24:57 +0700 Subject: [PATCH] Use `type_for_attribute` to determine primary_key for `find` --- .../dsl/compilers/active_record_relations.rb | 28 ++++++++++++++++-- lib/tapioca/helpers/rbi_helper.rb | 9 ++++++ .../compilers/active_record_relations_spec.rb | 29 +++++++++++++++++-- spec/tapioca/helpers/rbi_helper_spec.rb | 9 ++++++ 4 files changed, 70 insertions(+), 5 deletions(-) diff --git a/lib/tapioca/dsl/compilers/active_record_relations.rb b/lib/tapioca/dsl/compilers/active_record_relations.rb index bf8897e30..fa20646e2 100644 --- a/lib/tapioca/dsl/compilers/active_record_relations.rb +++ b/lib/tapioca/dsl/compilers/active_record_relations.rb @@ -3,6 +3,7 @@ return unless defined?(ActiveRecord::Base) +require "tapioca/dsl/helpers/active_model_type_helper" require "tapioca/dsl/helpers/active_record_constants_helper" module Tapioca @@ -657,9 +658,30 @@ def create_common_methods ) when :find # From ActiveRecord::ConnectionAdapter::Quoting#quote, minus nil - id_types = "T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, " \ - "::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, " \ - "::ActiveSupport::Duration, T::Class[T.anything])" + id_types = [ + "String", + "Symbol", + "::ActiveSupport::Multibyte::Chars", + "T::Boolean", + "BigDecimal", + "Numeric", + "::ActiveRecord::Type::Binary::Data", + "::ActiveRecord::Type::Time::Value", + "Date", + "Time", + "::ActiveSupport::Duration", + "T::Class[T.anything]", + ].to_set + + if constant.table_exists? + primary_key_type = constant.type_for_attribute(constant.primary_key) + type = Tapioca::Dsl::Helpers::ActiveModelTypeHelper.type_for(primary_key_type) + type = RBIHelper.as_non_nilable_type(type) + id_types << type if type != "T.untyped" + end + + id_types = "T.any(#{id_types.to_a.join(", ")})" + array_type = if constant.try(:composite_primary_key?) "T::Array[T::Array[#{id_types}]]" else diff --git a/lib/tapioca/helpers/rbi_helper.rb b/lib/tapioca/helpers/rbi_helper.rb index 6ef444a14..a2249280a 100644 --- a/lib/tapioca/helpers/rbi_helper.rb +++ b/lib/tapioca/helpers/rbi_helper.rb @@ -96,6 +96,15 @@ def as_nilable_type(type) end end + sig { params(type: String).returns(String) } + def as_non_nilable_type(type) + if type.match(/\A(?:::)?T.nilable\((.+)\)\z/) + T.must(::Regexp.last_match(1)) + else + type + end + end + sig { params(name: String).returns(T::Boolean) } def valid_method_name?(name) # try to parse a method definition with this name diff --git a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb index 1cb6d98dd..42d79a90a 100644 --- a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb +++ b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb @@ -48,8 +48,33 @@ class User end it "generates proper relation classes and modules" do + add_ruby_file("schema.rb", <<~RUBY) + ActiveRecord::Migration.suppress_messages do + ActiveRecord::Schema.define do + create_table :posts do |t| + end + end + end + RUBY + + add_ruby_file("custom_id.rb", <<~RUBY) + class CustomId < ActiveRecord::Type::Value + extend T::Sig + + sig { params(value: T.untyped).returns(T.nilable(CustomId)) } + def deserialize(value) + CustomId.new(value) unless value.nil? + end + + def serialize(value) + value + end + end + RUBY + add_ruby_file("post.rb", <<~RUBY) class Post < ActiveRecord::Base + attribute :id, CustomId.new end RUBY @@ -106,8 +131,8 @@ def fifth; end sig { returns(::Post) } def fifth!; end - sig { params(args: T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything])).returns(::Post) } - sig { params(args: T::Array[T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything])]).returns(T::Enumerable[::Post]) } + sig { params(args: T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything], ::CustomId)).returns(::Post) } + sig { params(args: T::Array[T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything], ::CustomId)]).returns(T::Enumerable[::Post]) } sig { params(args: NilClass, block: T.proc.params(object: ::Post).void).returns(T.nilable(::Post)) } def find(args = nil, &block); end diff --git a/spec/tapioca/helpers/rbi_helper_spec.rb b/spec/tapioca/helpers/rbi_helper_spec.rb index 50add629e..3a9af1d31 100644 --- a/spec/tapioca/helpers/rbi_helper_spec.rb +++ b/spec/tapioca/helpers/rbi_helper_spec.rb @@ -7,6 +7,15 @@ class Tapioca::RBIHelperSpec < Minitest::Spec include Tapioca::RBIHelper describe Tapioca::RBIHelper do + specify "as_non_nilable_type removes T.nilable() and ::T.nilable() if it's the outermost part of the string" do + T.bind(self, T.untyped) + + assert_equal(as_non_nilable_type("T.nilable(String)"), "String") + assert_equal(as_non_nilable_type("::T.nilable(String)"), "String") + assert_equal(as_non_nilable_type("String"), "String") + assert_equal(as_non_nilable_type("T.any(T.nilable(String), Integer)"), "T.any(T.nilable(String), Integer)") + end + it "accepts valid method names" do [ "f",