diff --git a/lib/tzinfo/data_source.rb b/lib/tzinfo/data_source.rb index 6288aec2..3dbac86b 100644 --- a/lib/tzinfo/data_source.rb +++ b/lib/tzinfo/data_source.rb @@ -247,6 +247,16 @@ def country_codes raise_invalid_data_source('country_codes') end + # Loads all timezone and country data into memory. This may be desirable in + # production environments to improve copy-on-write performance and to + # avoid flushing the constant cache every time a new timezone or country + # is loaded from {DataSources::RubyDataSource}. + def preload! + timezone_identifiers.each {|identifier| load_timezone_info(identifier) } + country_codes.each {|code| load_country_info(code) } + nil + end + # @return [String] a description of the {DataSource}. def to_s "Default DataSource" diff --git a/test/tc_data_source.rb b/test/tc_data_source.rb index f38686e2..28103e12 100644 --- a/test/tc_data_source.rb +++ b/test/tc_data_source.rb @@ -87,6 +87,37 @@ def call_lookup_country_info(hash, code, encoding = Encoding::UTF_8) end end + class PreloadTestDataSource < GetTimezoneIdentifiersTestDataSource + attr_reader :country_codes_called + attr_reader :loaded_timezones + attr_reader :loaded_countries + + def initialize(data_timezone_identifiers, linked_timezone_identifiers, country_codes) + super(data_timezone_identifiers, linked_timezone_identifiers) + @country_codes = country_codes + @country_codes_called = 0 + @loaded_timezones = [] + @loaded_countries = [] + end + + protected + + def country_codes + @country_codes_called += 1 + @country_codes + end + + def load_timezone_info(identifier) + @loaded_timezones << identifier + DataSources::TimezoneInfo.new(identifier) + end + + def load_country_info(code) + @loaded_countries << code + DataSources::CountryInfo.new(code, "Country #{code}", []) + end + end + def setup @orig_data_source = DataSource.get DataSource.set(InitDataSource.new) @@ -549,4 +580,19 @@ def test_lookup_country_info_case assert_equal(Encoding::UTF_8, error.message.encoding) end end + + def test_preload + data_timezone_identifiers = ['Data/Zone1', 'Data/Zone2'] + linked_timezone_identifiers = ['Linked/Zone1', 'Linked/Zone2'] + all_timezone_identifiers = data_timezone_identifiers + linked_timezone_identifiers + country_codes = ['AA', 'BB', 'CC'] + ds = PreloadTestDataSource.new(data_timezone_identifiers, linked_timezone_identifiers, country_codes) + assert_nil(ds.preload!) + assert_equal(1, ds.data_timezone_identifiers_called) + assert_equal(1, ds.linked_timezone_identifiers_called) + assert_equal(all_timezone_identifiers, ds.loaded_timezones) + assert_equal(all_timezone_identifiers, ds.instance_variable_get(:@timezone_identifiers)) + assert_equal(1, ds.country_codes_called) + assert_equal(country_codes, ds.loaded_countries) + end end