diff --git a/lib/jwt/jwk.rb b/lib/jwt/jwk.rb index 04aee444..01d8167a 100644 --- a/lib/jwt/jwk.rb +++ b/lib/jwt/jwk.rb @@ -1,36 +1,50 @@ # frozen_string_literal: true -require_relative 'jwk/key_abstract' -require_relative 'jwk/rsa' -require_relative 'jwk/hmac' require_relative 'jwk/key_finder' module JWT module JWK - MAPPINGS = { - 'RSA' => ::JWT::JWK::RSA, - OpenSSL::PKey::RSA => ::JWT::JWK::RSA, - 'oct' => ::JWT::JWK::HMAC, - String => ::JWT::JWK::HMAC - }.freeze - class << self def import(jwk_data) jwk_kty = jwk_data[:kty] || jwk_data['kty'] raise JWT::JWKError, 'Key type (kty) not provided' unless jwk_kty - MAPPINGS.fetch(jwk_kty.to_s) do |kty| + mappings.fetch(jwk_kty.to_s) do |kty| raise JWT::JWKError, "Key type #{kty} not supported" end.import(jwk_data) end def create_from(keypair) - MAPPINGS.fetch(keypair.class) do |klass| + mappings.fetch(keypair.class) do |klass| raise JWT::JWKError, "Cannot create JWK from a #{klass.name}" end.new(keypair) end + def classes + @mappings = nil # reset the cached mappings + @classes ||= [] + end + alias new create_from + + private + + def mappings + @mappings ||= generate_mappings + end + + def generate_mappings + classes.each_with_object({}) do |klass, hash| + next unless klass.const_defined?('KTYS') + Array(klass::KTYS).each do |kty| + hash[kty] = klass + end + end + end end end end + +require_relative 'jwk/key_base' +require_relative 'jwk/rsa' +require_relative 'jwk/hmac' diff --git a/lib/jwt/jwk/hmac.rb b/lib/jwt/jwk/hmac.rb index 10c4fec1..61839e97 100644 --- a/lib/jwt/jwk/hmac.rb +++ b/lib/jwt/jwk/hmac.rb @@ -2,8 +2,9 @@ module JWT module JWK - class HMAC < KeyAbstract + class HMAC < KeyBase KTY = 'oct'.freeze + KTYS = [KTY, String].freeze def initialize(keypair, kid = nil) raise ArgumentError, 'keypair must be of type String' unless keypair.is_a?(String) diff --git a/lib/jwt/jwk/key_abstract.rb b/lib/jwt/jwk/key_abstract.rb deleted file mode 100644 index 1251e2bc..00000000 --- a/lib/jwt/jwk/key_abstract.rb +++ /dev/null @@ -1,36 +0,0 @@ -# frozen_string_literal: true - -module JWT - module JWK - class KeyAbstract - attr_reader :keypair, :kid - - def initialize(keypair, kid = nil) - @keypair = keypair - @kid = kid - end - - def private? - raise NotImplementedError, "#{self.class} has not implemented method '#{__method__}'" - end - - def public_key - raise NotImplementedError, "#{self.class} has not implemented method '#{__method__}'" - end - - def export(_options = {}) - raise NotImplementedError, "#{self.class} has not implemented method '#{__method__}'" - end - - protected - - attr_writer :kid - - class << self - def import(_jwk_data) - raise NotImplementedError, "#{self.class} has not implemented method '#{__method__}'" - end - end - end - end -end diff --git a/lib/jwt/jwk/key_base.rb b/lib/jwt/jwk/key_base.rb new file mode 100644 index 00000000..46619a79 --- /dev/null +++ b/lib/jwt/jwk/key_base.rb @@ -0,0 +1,18 @@ +# frozen_string_literal: true + +module JWT + module JWK + class KeyBase + attr_reader :keypair, :kid + + def initialize(keypair, kid = nil) + @keypair = keypair + @kid = kid + end + + def self.inherited(klass) + ::JWT::JWK.classes << klass + end + end + end +end diff --git a/lib/jwt/jwk/rsa.rb b/lib/jwt/jwk/rsa.rb index 975131d4..bd66d509 100644 --- a/lib/jwt/jwk/rsa.rb +++ b/lib/jwt/jwk/rsa.rb @@ -2,15 +2,15 @@ module JWT module JWK - class RSA < KeyAbstract - BINARY = 2 - KTY = 'RSA'.freeze + class RSA < KeyBase + BINARY = 2 + KTY = 'RSA'.freeze + KTYS = [KTY, OpenSSL::PKey::RSA].freeze RSA_KEY_ELEMENTS = %i[n e d p q dp dq qi].freeze def initialize(keypair, kid = nil) raise ArgumentError, 'keypair must be of type OpenSSL::PKey::RSA' unless keypair.is_a?(OpenSSL::PKey::RSA) - super - self.kid ||= generate_kid + super(keypair, kid || generate_kid(keypair.public_key)) end def private? @@ -36,7 +36,7 @@ def export(options = {}) private - def generate_kid + def generate_kid(public_key) sequence = OpenSSL::ASN1::Sequence([OpenSSL::ASN1::Integer.new(public_key.n), OpenSSL::ASN1::Integer.new(public_key.e)]) OpenSSL::Digest::SHA256.hexdigest(sequence.to_der)