1 classdef Algorithm < Greedy.Interface
% default greedy basis generation class
%
% This implements an abstract version of the default greedy algorithm for
% reduced basis and empirical interpolation basis generation. The actual
% implementation of the extension and checking and error evaluation routines
% is delegated to an underlying object implementing a
% Greedy.Plugin.Interface. See \ref basisgen for example usage information.
%
% The entry point for the greedy algorithm is the method basis_extension().
% See there for more details on the algorithm.
% NOTE(review): the properties sections below are garbled by extraction:
% original line numbers are fused into the text, comment lines are wrapped,
% and several lines are missing -- the closing '};' of info_fields, the
% default assignments of stop_epsilon and stop_timeout, and the declaration
% of M_validation. Kept byte-identical here; reconstruct against the
% original source before use.
14 % cell array of fields to be copied to the detailed data
leaf node instance
15 % during init_basis().
16 info_fields = {
'M_train',
'M_validation', ...
17 'stop_epsilon',
'stop_max_val_train_ratio', ...
18 'val_train_ratio_seq_for_break', ...
% NOTE(review): closing '};' of the cell array lost in extraction.
23 % an
object of type SnapshotsGenerator.Cached
25 % This
object is inherited from the #detailed_extension.
% dependent properties delegated to the underlying plugin (see getters below)
29 properties(SetAccess=
private, Dependent)
41 properties(SetAccess =
private)
48 %
double specifying the maximum error indicator
for which the basis
49 % generation shall be stopped.
% NOTE(review): 'stop_epsilon = ...' default assignment missing here.
52 % integer specifying the number of seconds after which the basis generation
% NOTE(review): 'stop_timeout = ...' default assignment missing here.
56 % positive
double value specifying the maximum ratio between the maximum
57 % error indicator over the validation paramter
set and the maximum error
58 % over the trainining parameter
set,
for which the basis generation is
61 % Note, that #M_validation needs to be non-empty, if this value is set to
62 % something different than 'inf'.
63 stop_max_val_train_ratio = inf;
65 % number of subsequent cases where the max_err_val_train_ratio is too high
66 % that is needed in order to
break with a
67 % Greedy.Info.Base.stop_max_val_train_ratio condition.
69 % The
default is
'1', but especially
for basis generation with
71 % at least
'2', because
this algorithms can produce very strange basis
74 val_train_ratio_seq_for_break = 1;
77 % `M_{\text{train}} \subset \cal M`
78 M_train = ParameterSampling.Uniform(1);
81 % `M_{\text{val}} \subset \cal M`
% NOTE(review): 'M_validation = ...' declaration missing from extraction
% (the property exists: it is assigned in the constructor below).
function bgd = Algorithm(detailed_extension, M_train, initial_val_param_set)
  % constructor for a greedy basis generation algorithm
  %
  % This constructs a greedy basis generation algorithm using the following
  % arguments.
  %
  % Parameters:
  %  detailed_extension:    object implementing a Greedy.Plugin.Interface to
  %                         which extension and error evaluation is delegated.
  %  M_train:               object of type ParameterSampling.Interface
  %                         specifying the initial training parameter set.
  %  initial_val_param_set: object of type ParameterSampling.Interface
  %                         specifying an initial validation parameter set.
  %
  % NOTE(review): reconstructed from a garbled extraction; a few original
  % lines between the assignments were lost (presumably blanks) -- confirm.
  bgd.M_train            = M_train;
  bgd.M_validation       = initial_val_param_set;
  bgd.detailed_extension = detailed_extension;
end
function enabled = enable_validation(this, detailed_data)
  % function enabled = enable_validation(this, detailed_data)
  % indicates whether the validation routines shall be executed
  %
  % This method returns true when a validation parameter set is given and
  % the property #stop_max_val_train_ratio is not set to 'inf'.
  %
  % Parameters:
  %  detailed_data: of type .Greedy.DataTree.Detailed.ILeafNode
  %
  % Return values:
  %  enabled: boolean value
  enabled = ~isempty(detailed_data) ...
            && ~isempty(get_field(detailed_data, 'M_validation')) ...
            && ~isinf(get_field(detailed_data, 'stop_max_val_train_ratio'));
end
function btype = get.generated_basis_type(this)
  % getter for the dependent property #generated_basis_type
  % (delegates to the underlying plugin object)
  btype = this.detailed_extension.generated_basis_type;
end
function detailed_gen = get.detailed_gen(this)
  % getter for the dependent property #detailed_gen
  % (the snapshot generator owned by the underlying plugin object)
  detailed_gen = this.detailed_extension.generator;
end
function id = get.id(this)
  % getter for the dependent property #id
  % (delegates to the underlying plugin object)
  id = this.detailed_extension.id;
end
% NOTE(review): greedy extension loop -- garbled by extraction. Missing from
% the visible text (by original line numbering): all 'end' closers, the
% timer initializations ('alltimer', 'steptic'), the main 'while' loop
% header, the 'break' after a positive pre_check_for_end_meta(), and several
% argument lines of the error_indicators()/basis_extension() calls.
% Code lines below are kept byte-identical; only comments were added.
134 function detailed_data_tree = basis_extension(
this, rmodel, detailed_data_tree, checkpoint)
135 %
function detailed_data_tree = basis_extension(
this, rmodel, detailed_data_tree, checkpoint)
136 % greedy basis generation extension
139 %
"error_indicators()" routine, the method prepare() should have been
141 % @pre The argument 'detailed_data_tree' should be initialized via by the
142 % init_basis() method.
144 % During each extension step, the methods from the underlying
#detailed_extension
145 % are called in
this order:
146 % -# The maximum error indicators are gather by calling
147 % @ref Greedy.Plugin.Interface.error_indicators() "error_indicators()"
148 % for the training set parameters, and if existent also vor the
149 % validation set parameters.
150 % -# Before the actual basis extension, we check whether any termination
151 % condition has been reached by computing some default checks
152 % documented below and extension specific checks in
153 % @ref Greedy.Plugin.Interface.pre_check_for_end() "pre_check_for_end()".
154 % -# For the worst approximated parameter a basis extension call is
155 % initiated by calling the method
156 % @ref Greedy.Plugin.Interface.basis_extension() "basis_extension()".
159 % By default this methods checks
160 % - whether maximum value for the error indicator dropped below the
161 % value of 'get_field(detailed_data, ''''stop_epsilon'''')',
162 % - whether the timeout barrier #stop_timeout has been reached, or
163 % - whether the ratio between the maximum error indicators over the
164 % training set and the validation set exceeds the barrier of
165 % 'get_field(detailed_data, ''''stop_max_val_train_ratio'''')'
168 % detailed_data_tree: either a struct containing high dimensional model data
169 % needed to execute detailed simulations or an object
170 % of type Greedy.DataTree.Detailed.INode in case of resume from a
172 % checkpoint: object of type Greedy.Checkpoint specifying a given point in the
173 % algorithm where it basis generation can be resumed.
176 % detailed_data_tree: object of type Greedy.DataTree.Detailed.INode storing the
177 % reduced basis information in the leaf nodes and
178 % information on the reduced basis generation in every
% default-construct a fresh checkpoint when none was handed in
181 if nargin == 3 || isempty(checkpoint)
182 checkpoint = Greedy.Checkpoint;
% accumulated wall-clock time from a previous (resumed) run, 0 otherwise
185 basetoc =
get(checkpoint,
'toc', 0);
186 % prepare(
this, rmodel, detailed_data_tree);
% NOTE(review): the 'while' loop header and timer starts ('alltimer',
% 'steptic') were lost in extraction -- everything below through the
% checkpoint store ran once per greedy iteration.
% resolve the leaf node of the data tree belonging to rmodel
190 dd_leaf = get_active_leaf(detailed_data_tree, rmodel);
191 cur_M_train = get_field(dd_leaf,
'M_train');
192 cur_M_validation = get_field(dd_leaf,
'M_validation');
194 % ATTENTION: We assume, that prepare has been called beforehand,
196 % prepare(
this, rmodel, detailed_data_tree);
% sync reduced model sizes with the current basis sizes
202 rmodel.M = get_ei_size(detailed_data_tree);
203 rmodel.N = get_rb_size(detailed_data_tree);
205 fprintf(
'\n========================================================\n');
206 disp([
'Extended reduced bases to size N = ',...
207 num2str(rmodel.N),...
208 ', M = ', num2str(rmodel.M),
'.']);
% gather error indicators over the full training parameter set
210 disp([
'Computing ', num2str(size(cur_M_train.sample,1)),...
211 ' error indicators for basis extension:']);
212 [errs, max_err_seq, muind] = error_indicators(
this,...
214 detailed_data_tree,...
217 this.pretty_print_errs(
'Error sequence: ', errs);
% optional validation-set indicators and val/train ratio bookkeeping
219 if enable_validation(
this, dd_leaf)
221 disp(['Computing ', num2str(size(cur_M_validation.sample,1)),...
222 ' error indicators for validating basis extension:']);
223 val_errs = error_indicators(this,...
225 detailed_data_tree,...
226 cur_M_validation.sample,...
229 r_value = max(val_errs) / max(errs);
230 disp(['Max validation/training ratio: ', num2str(r_value)]);
231 append_field(dd_leaf, 'r_value_sequence', r_value);
% record this iteration's statistics, then test termination conditions
% (the 'break' on a true condition was lost in extraction)
234 Greedy.
Algorithm.push_back_extension_info(dd_leaf, errs, muind, max_err_seq);
235 if pre_check_for_end_meta(this, rmodel, detailed_data_tree, toc(alltimer)+basetoc);
236 set_field(dd_leaf, 'M_last_errors', errs);
% greedy choice: extend the basis for the worst-approximated parameter
240 mu = cur_M_train.sample(muind,:);
241 disp(['Choosing mu = ', num2str(mu(:)'), ...
242 ' for basis extension (error = ', num2str(max(errs)), ')']);
244 detailed_data_tree = basis_extension( this.detailed_extension,...
246 detailed_data_tree,...
249 % re-read dd_leaf, because basis extension algorithm is allowed to change the pointer
250 dd_leaf = get_active_leaf(detailed_data_tree, rmodel);
252 rmodel.N = get_rb_size(detailed_data_tree);
253 rmodel.M = get_ei_size(detailed_data_tree);
255 %! \todo check whether we should add another line
256 % errs = error_indicators(this, rmodel, this.detailed_gen, detailed_data);
257 % and do the validation afterwards!
259 append_field(dd_leaf, 'toc_value_sequence', toc(steptic));
% persist a checkpoint so the run can be resumed later
261 checkpoint = checkpoint.store(rmodel, detailed_data_tree,...
262 'greedy_extension', ...
263 struct('toc', toc(alltimer)+basetoc));
% after the loop: let the plugin finalize the generated data tree
267 detailed_data_tree = finalize(this.detailed_extension, rmodel, detailed_data_tree);
function [max_errs, max_err_sequence, max_mu_index] = error_indicators(this, rmodel, detailed_data_tree, parameter_set, reuse_reduced_data)
  % function [max_errs, max_err_sequence, max_mu_index] = error_indicators(this, rmodel, detailed_data_tree, parameter_set, reuse_reduced_data)
  % routine computing the error indicators (delegates to the plugin)
  %
  % Parameters:
  %  detailed_data_tree: an object of type Greedy.DataTree.Detailed.INode
  %                      storing the reduced basis functions in the leaf nodes.
  %  parameter_set:      a matrix of dimension
  %                      `\text{npar} \times \dim(\cal M)` containing in each
  %                      row a parameter vector for which the error indicator
  %                      should be computed.
  %  reuse_reduced_data: boolean flag indicating whether it is necessary to
  %                      recompute the reduced data or whether it should still
  %                      be valid since the last call to error_indicators().
  %                      (optional, defaults to false)
  %
  % Return values:
  %  max_errs:         for transient problems this returns the error sequence
  %                    over time for the parameter 'max_mu_index' with the
  %                    worst error.
  %  max_err_sequence: a vector of the same length as the number of training
  %                    set parameters, containing the maximum error in the
  %                    error sequence over time for each of those.
  %  max_mu_index:     the parameter index of the parameter `\mu_{\max}` in
  %                    'parameter_set' with the worst error.
  %
  % NOTE(review): the nargin guard below is a reconstruction of lines lost in
  % extraction (without it the optional argument would be dead) -- confirm.
  if nargin < 5
    reuse_reduced_data = false;
  end
  [max_errs, max_err_sequence, max_mu_index] = error_indicators(this.detailed_extension, rmodel, detailed_data_tree, parameter_set, reuse_reduced_data);
end
function prepare(this, rmodel, model_data)
  % initialization routine for basis extension
  %
  % This method is run by the gen_detailed_data() method before the execution
  % of the init_basis() methods and is used for preparation purposes of
  % - the training and validation parameter sample and
  % - caching detailed simulations if necessary for error_indicators()
  dmodel = rmodel.detailed_model;

  % initialize the training parameter sample if it has not been drawn yet
  if init_required(this.M_train)
    init_sample(this.M_train, rmodel);
  end

  % when resuming from an existing data tree, look up its active leaf;
  % NOTE(review): the 'else' branch is a reconstruction of lines lost in
  % extraction (implied by the isempty(dd_leaf) test below) -- confirm.
  if isa(model_data, 'Greedy.DataTree.Detailed.INode')
    dd_leaf = get_active_leaf(model_data, rmodel);
  else
    dd_leaf = [];
  end

  % determine the validation sample: prefer the one stored in the data tree,
  % otherwise fall back to this object's own M_validation
  validation_enabled = false;
  if enable_validation(this, dd_leaf)
    Msamples = get_field(dd_leaf, 'M_validation');
    validation_enabled = true;
  elseif isempty(dd_leaf) && ~isempty(this.M_validation) && ~isinf(this.stop_max_val_train_ratio)
    Msamples = this.M_validation;
    validation_enabled = true;
  end
  if validation_enabled && init_required(Msamples)
    init_sample(Msamples, rmodel);
  end

  % cache detailed simulations for both parameter sets if the plugin needs it
  if this.detailed_extension.needs_preparation
    if isa(model_data, 'Greedy.DataTree.Detailed.INode')
      detailed_data = model_data;
      model_data = detailed_data.model_data;
    end
    prepare(this.detailed_gen, dmodel, model_data, this.M_train.sample);
    % prepare validation set if existent
    if validation_enabled
      prepare(this.detailed_gen, dmodel, model_data, Msamples.sample);
    end
  end
end
function detailed_data_tree = init_basis(this, rmodel, model_data)
  % function detailed_data_tree = init_basis(this, rmodel, model_data)
  % @copybrief Greedy.Interface.init_basis()
  %
  % @copydoc Greedy.Interface.init_basis()

  % when resuming from an existing data tree, adopt its stored training set
  if isa(model_data, 'Greedy.DataTree.Detailed.INode')
    this.M_train = get_field_on_active_child(model_data, 'M_train', rmodel);
  end

  M = this.M_train.sample;
  detailed_data_tree = init_basis(this.detailed_extension, rmodel, model_data, M);

  % copy the info fields (cf. #info_fields) into every active leaf node that
  % has not been populated yet (first field used as an "initialized" sentinel)
  leaf_descrs = get_active_leaf_description(detailed_data_tree, rmodel);
  for leaf_descr = leaf_descrs
    detailed_data_leaf = get(detailed_data_tree, leaf_descr.basepath);
    if isempty(get_field(detailed_data_leaf, Greedy.Algorithm.info_fields{1}, []))
      set_fields(detailed_data_leaf, this, Greedy.Algorithm.info_fields);
    end
  end
end
375 methods (Access=private)
function breakcondition = pre_check_for_end_meta(this, rmodel, detailed_data_tree, time)
  % checks the default termination conditions for the greedy loop
  %
  % Checks, in order: stop_epsilon reached, stop_timeout exceeded, and the
  % validation/training ratio barrier; on each hit a stop flag is recorded on
  % the active leaf node. Finally the plugin-specific pre_check_for_end() is
  % consulted and the elapsed time is stored on the tree.
  %
  % Parameters:
  %  time: accumulated wall-clock seconds of the basis generation so far.
  %
  % Return values:
  %  breakcondition: boolean, true when the greedy loop shall terminate.
  dd_leaf = get_active_leaf(detailed_data_tree, rmodel);

  breakcondition = false;
  max_err_sequence = get_field(dd_leaf, 'max_err_sequence');
  if max_err_sequence(end) < get_field(dd_leaf, 'stop_epsilon')
    set_stop_flag(dd_leaf, 'stopped_on_epsilon');
    breakcondition = true;
  elseif time > this.stop_timeout
    set_stop_flag(dd_leaf, 'stopped_on_timeout');
    breakcondition = true;
  end

  % break when the last 'val_train_ratio_seq_for_break' ratios all exceed
  % the configured barrier
  if enable_validation(this, dd_leaf)
    val_seq_for_break = get_field(dd_leaf, 'val_train_ratio_seq_for_break');
    r_values = get_field(dd_leaf, 'r_value_sequence');
    minindex = min(val_seq_for_break, length(r_values))-1;
    if all(r_values(end-minindex:end) > get_field(dd_leaf, 'stop_max_val_train_ratio'))
      set_stop_flag(dd_leaf, 'stopped_on_max_val_train_ratio');
      breakcondition = true;
    end
  end

  % NOTE(review): the plugin call overwrites 'breakcondition' as in the
  % (garbled) original; presumably pre_check_for_end() inspects the stop
  % flags set above -- confirm against the original source.
  breakcondition = pre_check_for_end(this.detailed_extension, rmodel, detailed_data_tree);

  set_field(detailed_data_tree, 'elapsed_time', time);
end
406 methods (Access = private, Static)
function pretty_print_errs(title, errs)
  % pretty-prints a titled vector of error indicators, six values per row
  %
  % Parameters:
  %  title: string printed as the heading line.
  %  errs:  vector of error indicator values.
  %
  % NOTE(review): the line defining the numeric format spec was lost in
  % extraction; '%10.4e' is a plausible reconstruction -- confirm against
  % the original source.
  fmt = '%10.4e';
  nerrs = length(errs);
  rows = floor(nerrs/6);
  % full rows of six, then the remainder on a last, shorter row
  resherrs = reshape(errs(1:rows*6), rows, 6);
  resterrs = errs(rows*6+1:end);
  disp(char(title, num2str(resherrs, fmt), num2str(resterrs(:)', fmt)));
end
% NOTE(review): static helper recording per-iteration greedy statistics on
% the detailed data leaf: the maximum indicator, the full indicator vector,
% the chosen parameter index/value, and (for transient problems) the time
% index of the worst error. Garbled by extraction: line numbers are fused
% into the code, string literals are split across lines, and the closing
% 'end' lies past the visible range. Code kept byte-identical.
416 function push_back_extension_info(detailed_data, errs, muind, max_err_seq)
418 append_field(detailed_data, 'max_err_sequence', max(errs));
419 append_field(detailed_data, 'errs_sequence', {errs});
420 append_field(detailed_data,
'mu_ind_sequence', muind);
421 cur_M_train = get_field(detailed_data,
'M_train');
422 append_field(detailed_data,
'mu_sequence', cur_M_train.sample(muind,:)
');
423 set_field(detailed_data, 'M_last_errors
', errs);
% record at which time index the worst error occurred
425 [dummy, maxt] = max(max_err_seq);
426 append_field(detailed_data, 'max_time_index_sequence
', maxt);