From 63f1a4a60095ce5dd7679d79a149c66e19c4411d Mon Sep 17 00:00:00 2001 From: Joakim Antman Date: Thu, 1 Oct 2020 22:48:50 +0300 Subject: [PATCH] Trying to split the import and export --- lib/jwt/jwk/rsa.rb | 79 ++++++++++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/lib/jwt/jwk/rsa.rb b/lib/jwt/jwk/rsa.rb index 457b6de7..536736ad 100644 --- a/lib/jwt/jwk/rsa.rb +++ b/lib/jwt/jwk/rsa.rb @@ -5,8 +5,9 @@ module JWK class RSA < KeyAbstract attr_reader :keypair - BINARY = 2 - KTY = 'RSA'.freeze + BINARY = 2 + KTY = 'RSA'.freeze + RSA_KEY_ELEMENTS = %i[n e d p q dp dq qi].freeze def initialize(keypair, kid = nil) raise ArgumentError, 'keypair must be of type OpenSSL::PKey::RSA' unless keypair.is_a?(OpenSSL::PKey::RSA) @@ -28,16 +29,22 @@ def kid end def export(options = {}) - ret = { + exported_hash = { kty: KTY, n: encode_open_ssl_bn(public_key.n), e: encode_open_ssl_bn(public_key.e), kid: kid } - return ret unless private? && options[:include_private] == true + return exported_hash unless private? && options[:include_private] == true - ret.merge( + append_private_parts(exported_hash) + end + + private + + def append_private_parts(the_hash) + the_hash.merge( d: encode_open_ssl_bn(keypair.d), p: encode_open_ssl_bn(keypair.p), q: encode_open_ssl_bn(keypair.q), @@ -47,8 +54,6 @@ def export(options = {}) ) end - private - def generate_kid sequence = OpenSSL::ASN1::Sequence([OpenSSL::ASN1::Integer.new(public_key.n), OpenSSL::ASN1::Integer.new(public_key.e)]) @@ -61,42 +66,54 @@ def encode_open_ssl_bn(key_part) class << self def import(jwk_data) - self.new(rsa_pkey(*jwk_attrs(jwk_data, :n, :e, :d, :p, :q, :dp, :dq, :qi)), jwk_data[:kid]) + pkey_params = jwk_attributes(jwk_data, *RSA_KEY_ELEMENTS) do |value| + decode_open_ssl_bn(value) + end + kid = jwk_attributes(jwk_data, :kid)[:kid] + self.new(rsa_pkey(pkey_params), kid) end private - def jwk_attrs(jwk_data, *attrs) - attrs.map do |attr| - decode_open_ssl_bn(jwk_data[attr] || jwk_data[attr.to_s]) + def jwk_attributes(jwk_data, *attributes) + attributes.each_with_object({}) do |attribute, hash| + value = jwk_data[attribute] || jwk_data[attribute.to_s] + value = yield(value) if block_given? + hash[attribute] = value end end - def rsa_pkey(jwk_n, jwk_e, jwk_d, jwk_p, jwk_q, jwk_dp, jwk_dq, jwk_qi) - raise JWT::JWKError, 'Key format is invalid for RSA' unless jwk_n && jwk_e - - key = OpenSSL::PKey::RSA.new - - if key.respond_to?(:set_key) - key.set_key(jwk_n, jwk_e, jwk_d) - key.set_factors(jwk_p, jwk_q) if jwk_p && jwk_q - key.set_crt_params(jwk_dp, jwk_dq, jwk_qi) if jwk_dp && jwk_dq && jwk_qi - else - key.n = jwk_n - key.e = jwk_e - key.d = jwk_d if jwk_d - key.p = jwk_p if jwk_p - key.q = jwk_q if jwk_q - key.dmp1 = jwk_dp if jwk_dp - key.dmq1 = jwk_dq if jwk_dq - key.iqmp = jwk_qi if jwk_qi - end + def rsa_pkey(rsa_parameters) + raise JWT::JWKError, 'Key format is invalid for RSA' unless rsa_parameters[:n] && rsa_parameters[:e] + + populate_key(OpenSSL::PKey::RSA.new, rsa_parameters) + end - key + if OpenSSL::PKey::RSA.new.respond_to?(:set_key) + def populate_key(rsa_key, rsa_parameters) + rsa_key.set_key(rsa_parameters[:n], rsa_parameters[:e], rsa_parameters[:d]) + rsa_key.set_factors(rsa_parameters[:p], rsa_parameters[:q]) if rsa_parameters[:p] && rsa_parameters[:q] + rsa_key.set_crt_params(rsa_parameters[:dp], rsa_parameters[:dq], rsa_parameters[:qi]) if rsa_parameters[:dp] && rsa_parameters[:dq] && rsa_parameters[:qi] + rsa_key + end + else + def populate_key(rsa_key, rsa_parameters) # rubocop:disable Metrics/CyclomaticComplexity + rsa_key.n = rsa_parameters[:n] + rsa_key.e = rsa_parameters[:e] + rsa_key.d = rsa_parameters[:d] if rsa_parameters[:d] + rsa_key.p = rsa_parameters[:p] if rsa_parameters[:p] + rsa_key.q = rsa_parameters[:q] if rsa_parameters[:q] + rsa_key.dmp1 = rsa_parameters[:dp] if rsa_parameters[:dp] + rsa_key.dmq1 = rsa_parameters[:dq] if rsa_parameters[:dq] + rsa_key.iqmp = rsa_parameters[:qi] if rsa_parameters[:qi] + + rsa_key + end end def decode_open_ssl_bn(jwk_data) return nil if jwk_data.nil? + OpenSSL::BN.new(::Base64.urlsafe_decode64(jwk_data), BINARY) end end