From cb9020ba548e2391d7f04bd431151fd52183e687 Mon Sep 17 00:00:00 2001 From: HoneyryderChuck Date: Wed, 27 Mar 2024 12:43:30 +0000 Subject: [PATCH] changing sslsocket typeddata from SSL to an intermediate structure wrapping SSL preparing for future approach involving using BIO_mem with something which quacks like an IO --- ext/openssl/ossl_ssl.c | 161 ++++++++++++++++++--------------- ext/openssl/ossl_ssl.h | 6 +- ext/openssl/ossl_ssl_session.c | 4 +- 3 files changed, 94 insertions(+), 77 deletions(-) diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index c63f5868f..6dbda13b3 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -571,13 +571,13 @@ ossl_call_servername_cb(VALUE ary) ret_obj = rb_funcallv(cb, id_call, 1, &ary); if (rb_obj_is_kind_of(ret_obj, cSSLContext)) { - SSL *ssl; + rb_ssl_str *ssl; SSL_CTX *ctx2; ossl_sslctx_setup(ret_obj); GetSSL(ssl_obj, ssl); GetSSLCTX(ret_obj, ctx2); - SSL_set_SSL_CTX(ssl, ctx2); + SSL_set_SSL_CTX(ssl->ssl, ctx2); rb_ivar_set(ssl_obj, id_i_context, ret_obj); } else if (!NIL_P(ret_obj)) { ossl_raise(rb_eArgError, "servername_cb must return an " @@ -1551,19 +1551,21 @@ ssl_started(SSL *ssl) static void ossl_ssl_mark(void *ptr) { - SSL *ssl = ptr; - rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_ptr_idx)); + rb_ssl_str *ssl = ptr; + rb_gc_mark((VALUE)SSL_get_ex_data(ssl->ssl, ossl_ssl_ex_ptr_idx)); // Note: this reference is stored as @verify_callback so we don't need to mark it. // However we do need to ensure GC compaction won't move it, hence why // we call rb_gc_mark here. - rb_gc_mark((VALUE)SSL_get_ex_data(ssl, ossl_ssl_ex_vcb_idx)); + rb_gc_mark((VALUE)SSL_get_ex_data(ssl->ssl, ossl_ssl_ex_vcb_idx)); } static void -ossl_ssl_free(void *ssl) +ossl_ssl_free(void *ptr) { - SSL_free(ssl); + rb_ssl_str *ssl = ptr; + SSL_free(ssl->ssl); + xfree(ssl); } const rb_data_type_t ossl_ssl_type = { @@ -1574,10 +1576,13 @@ const rb_data_type_t ossl_ssl_type = { 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, }; + static VALUE ossl_ssl_s_alloc(VALUE klass) { - return TypedData_Wrap_Struct(klass, &ossl_ssl_type, NULL); + rb_ssl_str *ssl = xmalloc(sizeof(rb_ssl_str)); + ssl->ssl = NULL; + return TypedData_Wrap_Struct(klass, &ossl_ssl_type, ssl); } static VALUE @@ -1623,11 +1628,12 @@ static VALUE ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) { VALUE io, v_ctx, verify_cb; + rb_ssl_str *ssl_ex; SSL *ssl; SSL_CTX *ctx; - TypedData_Get_Struct(self, SSL, &ossl_ssl_type, ssl); - if (ssl) + TypedData_Get_Struct(self, rb_ssl_str, &ossl_ssl_type, ssl_ex); + if (ssl_ex->ssl) ossl_raise(eSSLError, "SSL already initialized"); if (rb_scan_args(argc, argv, "11", &io, &v_ctx) == 1) @@ -1639,13 +1645,13 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) if (rb_respond_to(io, rb_intern("nonblock="))) rb_funcall(io, rb_intern("nonblock="), 1, Qtrue); - Check_Type(io, T_FILE); + // Check_Type(io, T_FILE); rb_ivar_set(self, id_i_io, io); ssl = SSL_new(ctx); if (!ssl) ossl_raise(eSSLError, NULL); - RTYPEDDATA_DATA(self) = ssl; + ((rb_ssl_str *) RTYPEDDATA_DATA(self))->ssl = ssl; SSL_set_ex_data(ssl, ossl_ssl_ex_ptr_idx, (void *)self); SSL_set_info_callback(ssl, ssl_info_cb); @@ -1674,19 +1680,26 @@ static VALUE ossl_ssl_setup(VALUE self) { VALUE io; - SSL *ssl; + rb_ssl_str *ssl; rb_io_t *fptr; GetSSL(self, ssl); - if (ssl_started(ssl)) + if (ssl_started(ssl->ssl)) return Qtrue; io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - rb_io_check_readable(fptr); - rb_io_check_writable(fptr); - if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) - ossl_raise(eSSLError, "SSL_set_fd"); + + if (TYPE(io) == T_FILE) { + GetOpenFile(io, fptr); + rb_io_check_readable(fptr); + rb_io_check_writable(fptr); + if (!SSL_set_fd(ssl->ssl, TO_SOCKET(rb_io_descriptor(io)))) + ossl_raise(eSSLError, "SSL_set_fd"); + } else { + // something which quacks like an IO + + } + return Qtrue; } @@ -1779,7 +1792,7 @@ io_wait_readable(VALUE io) static VALUE ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) { - SSL *ssl; + rb_ssl_str *ssl; int ret, ret2; VALUE cb_state; int nonblock = opts != Qfalse; @@ -1790,7 +1803,7 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) VALUE io = rb_attr_get(self, id_i_io); for (;;) { - ret = func(ssl); + ret = func(ssl->ssl); cb_state = rb_attr_get(self, ID_callback_state); if (!NIL_P(cb_state)) { @@ -1802,7 +1815,7 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) if (ret > 0) break; - switch ((ret2 = ssl_get_error(ssl, ret))) { + switch ((ret2 = ssl_get_error(ssl->ssl, ret))) { case SSL_ERROR_WANT_WRITE: if (no_exception_p(opts)) { return sym_wait_writable; } write_would_block(nonblock); @@ -1828,7 +1841,7 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) if (ERR_GET_LIB(err) == ERR_LIB_SSL && ERR_GET_REASON(err) == SSL_R_CERTIFICATE_VERIFY_FAILED) { const char *err_msg = ERR_reason_error_string(err), - *verify_msg = X509_verify_cert_error_string(SSL_get_verify_result(ssl)); + *verify_msg = X509_verify_cert_error_string(SSL_get_verify_result(ssl->ssl)); if (!err_msg) err_msg = "(null)"; if (!verify_msg) @@ -1844,7 +1857,7 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) ret2, errno, peeraddr_ip_str(self), - SSL_state_string_long(ssl), + SSL_state_string_long(ssl->ssl), error_append); } } @@ -1950,7 +1963,7 @@ ossl_ssl_accept_nonblock(int argc, VALUE *argv, VALUE self) static VALUE ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) { - SSL *ssl; + rb_ssl_str *ssl; int ilen; VALUE len, str; VALUE opts = Qnil; @@ -1961,7 +1974,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) rb_scan_args(argc, argv, "11", &len, &str); } GetSSL(self, ssl); - if (!ssl_started(ssl)) + if (!ssl_started(ssl->ssl)) rb_raise(eSSLError, "SSL session is not started yet"); ilen = NUM2INT(len); @@ -1982,8 +1995,8 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) rb_str_locktmp(str); for (;;) { - int nread = ossl_ssl_read_impl(ssl, str, ilen); - switch (ssl_get_error(ssl, nread)) { + int nread = ossl_ssl_read_impl(ssl->ssl, str, ilen); + switch (ssl_get_error(ssl->ssl, nread)) { case SSL_ERROR_NONE: rb_str_unlocktmp(str); rb_str_set_len(str, nread); @@ -2069,13 +2082,13 @@ ossl_ssl_read_nonblock(int argc, VALUE *argv, VALUE self) static VALUE ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) { - SSL *ssl; + rb_ssl_str *ssl; rb_io_t *fptr; int num, nonblock = opts != Qfalse; VALUE tmp; GetSSL(self, ssl); - if (!ssl_started(ssl)) + if (!ssl_started(ssl->ssl)) rb_raise(eSSLError, "SSL session is not started yet"); tmp = rb_str_new_frozen(StringValue(str)); @@ -2088,8 +2101,8 @@ ossl_ssl_write_internal(VALUE self, VALUE str, VALUE opts) return INT2FIX(0); for (;;) { - int nwritten = ossl_ssl_write_impl(ssl, tmp, num); - switch (ssl_get_error(ssl, nwritten)) { + int nwritten = ossl_ssl_write_impl(ssl->ssl, tmp, num); + switch (ssl_get_error(ssl->ssl, nwritten)) { case SSL_ERROR_NONE: return INT2NUM(nwritten); case SSL_ERROR_WANT_WRITE: @@ -2160,13 +2173,13 @@ ossl_ssl_write_nonblock(int argc, VALUE *argv, VALUE self) static VALUE ossl_ssl_stop(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; int ret; GetSSL(self, ssl); - if (!ssl_started(ssl)) + if (!ssl_started(ssl->ssl)) return Qnil; - ret = SSL_shutdown(ssl); + ret = SSL_shutdown(ssl->ssl); if (ret == 1) /* Have already received close_notify */ return Qnil; if (ret == 0) /* Sent close_notify, but we don't wait for reply */ @@ -2191,7 +2204,7 @@ ossl_ssl_stop(VALUE self) static VALUE ossl_ssl_get_cert(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; X509 *cert = NULL; GetSSL(self, ssl); @@ -2200,7 +2213,7 @@ ossl_ssl_get_cert(VALUE self) * Is this OpenSSL bug? Should add a ref? * TODO: Ask for. */ - cert = SSL_get_certificate(ssl); /* NO DUPs => DON'T FREE. */ + cert = SSL_get_certificate(ssl->ssl); /* NO DUPs => DON'T FREE. */ if (!cert) { return Qnil; @@ -2217,13 +2230,13 @@ ossl_ssl_get_cert(VALUE self) static VALUE ossl_ssl_get_peer_cert(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; X509 *cert = NULL; VALUE obj; GetSSL(self, ssl); - cert = SSL_get_peer_certificate(ssl); /* Adds a ref => Safe to FREE. */ + cert = SSL_get_peer_certificate(ssl->ssl); /* Adds a ref => Safe to FREE. */ if (!cert) { return Qnil; @@ -2243,7 +2256,7 @@ ossl_ssl_get_peer_cert(VALUE self) static VALUE ossl_ssl_get_peer_cert_chain(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; STACK_OF(X509) *chain; X509 *cert; VALUE ary; @@ -2251,7 +2264,7 @@ ossl_ssl_get_peer_cert_chain(VALUE self) GetSSL(self, ssl); - chain = SSL_get_peer_cert_chain(ssl); + chain = SSL_get_peer_cert_chain(ssl->ssl); if(!chain) return Qnil; num = sk_X509_num(chain); ary = rb_ary_new2(num); @@ -2273,11 +2286,11 @@ ossl_ssl_get_peer_cert_chain(VALUE self) static VALUE ossl_ssl_get_version(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; GetSSL(self, ssl); - return rb_str_new2(SSL_get_version(ssl)); + return rb_str_new2(SSL_get_version(ssl->ssl)); } /* @@ -2290,11 +2303,11 @@ ossl_ssl_get_version(VALUE self) static VALUE ossl_ssl_get_cipher(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; const SSL_CIPHER *cipher; GetSSL(self, ssl); - cipher = SSL_get_current_cipher(ssl); + cipher = SSL_get_current_cipher(ssl->ssl); return cipher ? ossl_ssl_cipher_to_ary(cipher) : Qnil; } @@ -2308,15 +2321,15 @@ ossl_ssl_get_cipher(VALUE self) static VALUE ossl_ssl_get_state(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; VALUE ret; GetSSL(self, ssl); - ret = rb_str_new2(SSL_state_string(ssl)); + ret = rb_str_new2(SSL_state_string(ssl->ssl)); if (ruby_verbose) { rb_str_cat2(ret, ": "); - rb_str_cat2(ret, SSL_state_string_long(ssl)); + rb_str_cat2(ret, SSL_state_string_long(ssl->ssl)); } return ret; } @@ -2330,11 +2343,11 @@ ossl_ssl_get_state(VALUE self) static VALUE ossl_ssl_pending(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; GetSSL(self, ssl); - return INT2NUM(SSL_pending(ssl)); + return INT2NUM(SSL_pending(ssl->ssl)); } /* @@ -2346,11 +2359,11 @@ ossl_ssl_pending(VALUE self) static VALUE ossl_ssl_session_reused(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; GetSSL(self, ssl); - return SSL_session_reused(ssl) ? Qtrue : Qfalse; + return SSL_session_reused(ssl->ssl) ? Qtrue : Qfalse; } /* @@ -2362,13 +2375,13 @@ ossl_ssl_session_reused(VALUE self) static VALUE ossl_ssl_set_session(VALUE self, VALUE arg1) { - SSL *ssl; + rb_ssl_str *ssl; SSL_SESSION *sess; GetSSL(self, ssl); GetSSLSession(arg1, sess); - if (SSL_set_session(ssl, sess) != 1) + if (SSL_set_session(ssl->ssl, sess) != 1) ossl_raise(eSSLError, "SSL_set_session"); return arg1; @@ -2384,7 +2397,7 @@ ossl_ssl_set_session(VALUE self, VALUE arg1) static VALUE ossl_ssl_set_hostname(VALUE self, VALUE arg) { - SSL *ssl; + rb_ssl_str *ssl; char *hostname = NULL; GetSSL(self, ssl); @@ -2392,7 +2405,7 @@ ossl_ssl_set_hostname(VALUE self, VALUE arg) if (!NIL_P(arg)) hostname = StringValueCStr(arg); - if (!SSL_set_tlsext_host_name(ssl, hostname)) + if (!SSL_set_tlsext_host_name(ssl->ssl, hostname)) ossl_raise(eSSLError, NULL); /* for SSLSocket#hostname */ @@ -2413,11 +2426,11 @@ ossl_ssl_set_hostname(VALUE self, VALUE arg) static VALUE ossl_ssl_get_verify_result(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; GetSSL(self, ssl); - return LONG2NUM(SSL_get_verify_result(ssl)); + return LONG2NUM(SSL_get_verify_result(ssl->ssl)); } /* @@ -2430,18 +2443,18 @@ ossl_ssl_get_verify_result(VALUE self) static VALUE ossl_ssl_get_finished(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; char sizer[1], *buf; size_t len; GetSSL(self, ssl); - len = SSL_get_finished(ssl, sizer, 0); + len = SSL_get_finished(ssl->ssl, sizer, 0); if (len == 0) return Qnil; buf = ALLOCA_N(char, len); - SSL_get_finished(ssl, buf, len); + SSL_get_finished(ssl->ssl, buf, len); return rb_str_new(buf, len); } @@ -2455,18 +2468,18 @@ ossl_ssl_get_finished(VALUE self) static VALUE ossl_ssl_get_peer_finished(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; char sizer[1], *buf; size_t len; GetSSL(self, ssl); - len = SSL_get_peer_finished(ssl, sizer, 0); + len = SSL_get_peer_finished(ssl->ssl, sizer, 0); if (len == 0) return Qnil; buf = ALLOCA_N(char, len); - SSL_get_peer_finished(ssl, buf, len); + SSL_get_peer_finished(ssl->ssl, buf, len); return rb_str_new(buf, len); } @@ -2484,12 +2497,12 @@ ossl_ssl_get_peer_finished(VALUE self) static VALUE ossl_ssl_get_client_ca_list(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; STACK_OF(X509_NAME) *ca; GetSSL(self, ssl); - ca = SSL_get_client_CA_list(ssl); + ca = SSL_get_client_CA_list(ssl->ssl); return ossl_x509name_sk2ary(ca); } @@ -2504,13 +2517,13 @@ ossl_ssl_get_client_ca_list(VALUE self) static VALUE ossl_ssl_npn_protocol(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; const unsigned char *out; unsigned int outlen; GetSSL(self, ssl); - SSL_get0_next_proto_negotiated(ssl, &out, &outlen); + SSL_get0_next_proto_negotiated(ssl->ssl, &out, &outlen); if (!outlen) return Qnil; else @@ -2528,13 +2541,13 @@ ossl_ssl_npn_protocol(VALUE self) static VALUE ossl_ssl_alpn_protocol(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; const unsigned char *out; unsigned int outlen; GetSSL(self, ssl); - SSL_get0_alpn_selected(ssl, &out, &outlen); + SSL_get0_alpn_selected(ssl->ssl, &out, &outlen); if (!outlen) return Qnil; else @@ -2550,7 +2563,7 @@ ossl_ssl_alpn_protocol(VALUE self) static VALUE ossl_ssl_export_keying_material(int argc, VALUE *argv, VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; VALUE str; VALUE label; VALUE length; @@ -2576,7 +2589,7 @@ ossl_ssl_export_keying_material(int argc, VALUE *argv, VALUE self) ctx = (unsigned char *)RSTRING_PTR(context); ctx_len = RSTRING_LEN(context); } - ret = SSL_export_keying_material(ssl, p, len, (char *)RSTRING_PTR(label), + ret = SSL_export_keying_material(ssl->ssl, p, len, (char *)RSTRING_PTR(label), RSTRING_LENINT(label), ctx, ctx_len, use_ctx); if (ret == 0 || ret == -1) { ossl_raise(eSSLError, "SSL_export_keying_material"); @@ -2593,11 +2606,11 @@ ossl_ssl_export_keying_material(int argc, VALUE *argv, VALUE self) static VALUE ossl_ssl_tmp_key(VALUE self) { - SSL *ssl; + rb_ssl_str *ssl; EVP_PKEY *key; GetSSL(self, ssl); - if (!SSL_get_server_tmp_key(ssl, &key)) + if (!SSL_get_server_tmp_key(ssl->ssl, &key)) return Qnil; return ossl_pkey_new(key); } diff --git a/ext/openssl/ossl_ssl.h b/ext/openssl/ossl_ssl.h index 535c56097..cde04959f 100644 --- a/ext/openssl/ossl_ssl.h +++ b/ext/openssl/ossl_ssl.h @@ -10,8 +10,12 @@ #if !defined(_OSSL_SSL_H_) #define _OSSL_SSL_H_ +typedef struct rb_ssl { + SSL *ssl; +} rb_ssl_str; + #define GetSSL(obj, ssl) do { \ - TypedData_Get_Struct((obj), SSL, &ossl_ssl_type, (ssl)); \ + TypedData_Get_Struct((obj), rb_ssl_str, &ossl_ssl_type, (ssl)); \ if (!(ssl)) { \ ossl_raise(rb_eRuntimeError, "SSL is not initialized"); \ } \ diff --git a/ext/openssl/ossl_ssl_session.c b/ext/openssl/ossl_ssl_session.c index c5df902c6..5571cae01 100644 --- a/ext/openssl/ossl_ssl_session.c +++ b/ext/openssl/ossl_ssl_session.c @@ -44,11 +44,11 @@ ossl_ssl_session_initialize(VALUE self, VALUE arg1) ossl_raise(eSSLSession, "SSL Session already initialized"); if (rb_obj_is_instance_of(arg1, cSSLSocket)) { - SSL *ssl; + rb_ssl_str *ssl; GetSSL(arg1, ssl); - if ((ctx = SSL_get1_session(ssl)) == NULL) + if ((ctx = SSL_get1_session(ssl->ssl)) == NULL) ossl_raise(eSSLSession, "no session available"); } else {