Skip to content

Commit

Permalink
Use type_for_attribute to determine primary_key for find
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-tan committed Apr 15, 2024
1 parent f00b6c1 commit 9c563a2
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 5 deletions.
28 changes: 25 additions & 3 deletions lib/tapioca/dsl/compilers/active_record_relations.rb
Expand Up @@ -3,6 +3,7 @@

return unless defined?(ActiveRecord::Base)

require "tapioca/dsl/helpers/active_model_type_helper"
require "tapioca/dsl/helpers/active_record_constants_helper"

module Tapioca
Expand Down Expand Up @@ -657,9 +658,30 @@ def create_common_methods
)
when :find
# From ActiveRecord::ConnectionAdapter::Quoting#quote, minus nil
id_types = "T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, " \
"::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, " \
"::ActiveSupport::Duration, T::Class[T.anything])"
id_types = [
"String",
"Symbol",
"::ActiveSupport::Multibyte::Chars",
"T::Boolean",
"BigDecimal",
"Numeric",
"::ActiveRecord::Type::Binary::Data",
"::ActiveRecord::Type::Time::Value",
"Date",
"Time",
"::ActiveSupport::Duration",
"T::Class[T.anything]",
].to_set

if constant.table_exists?
primary_key_type = constant.type_for_attribute(constant.primary_key)
type = Tapioca::Dsl::Helpers::ActiveModelTypeHelper.type_for(primary_key_type)
type = RBIHelper.as_non_nilable_type(type)
id_types << type if type != "T.untyped"
end

id_types = "T.any(#{id_types.to_a.join(", ")})"

array_type = if constant.try(:composite_primary_key?)
"T::Array[T::Array[#{id_types}]]"
else
Expand Down
9 changes: 9 additions & 0 deletions lib/tapioca/helpers/rbi_helper.rb
Expand Up @@ -96,6 +96,15 @@ def as_nilable_type(type)
end
end

sig { params(type: String).returns(String) }
def as_non_nilable_type(type)
if type.match(/\A(?:::)?T.nilable\((.+)\)\z/)
T.must(::Regexp.last_match(1))
else
type
end
end

sig { params(name: String).returns(T::Boolean) }
def valid_method_name?(name)
# try to parse a method definition with this name
Expand Down
29 changes: 27 additions & 2 deletions spec/tapioca/dsl/compilers/active_record_relations_spec.rb
Expand Up @@ -48,8 +48,33 @@ class User
end

it "generates proper relation classes and modules" do
add_ruby_file("schema.rb", <<~RUBY)
ActiveRecord::Migration.suppress_messages do
ActiveRecord::Schema.define do
create_table :posts do |t|
end
end
end
RUBY

add_ruby_file("custom_id.rb", <<~RUBY)
class CustomId < ActiveRecord::Type::Value
extend T::Sig
sig { params(value: T.untyped).returns(T.nilable(CustomId)) }
def deserialize(value)
CustomId.new(value) unless value.nil?
end
def serialize(value)
value
end
end
RUBY

add_ruby_file("post.rb", <<~RUBY)
class Post < ActiveRecord::Base
attribute :id, CustomId.new
end
RUBY

Expand Down Expand Up @@ -106,8 +131,8 @@ def fifth; end
sig { returns(::Post) }
def fifth!; end
sig { params(args: T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything])).returns(::Post) }
sig { params(args: T::Array[T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything])]).returns(T::Enumerable[::Post]) }
sig { params(args: T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything], ::CustomId)).returns(::Post) }
sig { params(args: T::Array[T.any(String, Symbol, ::ActiveSupport::Multibyte::Chars, T::Boolean, BigDecimal, Numeric, ::ActiveRecord::Type::Binary::Data, ::ActiveRecord::Type::Time::Value, Date, Time, ::ActiveSupport::Duration, T::Class[T.anything], ::CustomId)]).returns(T::Enumerable[::Post]) }
sig { params(args: NilClass, block: T.proc.params(object: ::Post).void).returns(T.nilable(::Post)) }
def find(args = nil, &block); end
Expand Down
9 changes: 9 additions & 0 deletions spec/tapioca/helpers/rbi_helper_spec.rb
Expand Up @@ -7,6 +7,15 @@ class Tapioca::RBIHelperSpec < Minitest::Spec
include Tapioca::RBIHelper

describe Tapioca::RBIHelper do
specify "as_non_nilable_type removes T.nilable() and ::T.nilable() if it's the outermost part of the string" do
T.bind(self, T.untyped)

assert_equal(as_non_nilable_type("T.nilable(String)"), "String")
assert_equal(as_non_nilable_type("::T.nilable(String)"), "String")
assert_equal(as_non_nilable_type("String"), "String")
assert_equal(as_non_nilable_type("T.any(T.nilable(String), Integer)"), "T.any(T.nilable(String), Integer)")
end

it "accepts valid method names" do
[
"f",
Expand Down

0 comments on commit 9c563a2

Please sign in to comment.