7 #ifndef PORTAGE_DRIVER_PARTS_H 8 #define PORTAGE_DRIVER_PARTS_H 15 #include <unordered_set> 19 #include <type_traits> 25 #define DEBUG_PART_BY_PART 0 34 using Wonton::Entity_kind;
35 using Wonton::Entity_type;
38 template<
class Mesh,
class State>
60 volumes_.resize(size_);
61 lookup_.reserve(size_);
64 for (
int i = 0; i < size_; ++i) {
65 auto const& s = cells[i];
82 const Mesh&
mesh()
const {
return mesh_; }
89 const State&
state()
const {
return state_; }
96 State&
state() {
return const_cast<State&
>(state_); }
103 std::vector<int>
const&
cells()
const {
return cells_; }
111 bool contains(
int id)
const {
return lookup_.count(
id) == 1; }
119 const int&
index(
int id)
const {
return index_.at(
id); }
126 const int&
size()
const {
return size_; }
136 if (cached_volumes) {
139 throw std::runtime_error(
"Error: cells volumes not yet computed");
149 template<Entity_type type = Entity_type::ALL>
151 std::vector<int> neigh, filtered;
153 mesh_.cell_get_node_adj_cells(entity, type, &neigh);
155 filtered.reserve(neigh.size());
156 for (
auto const& current : neigh) {
157 if (lookup_.count(current)) {
158 filtered.emplace_back(current);
170 if (cached_volumes) {
171 return std::accumulate(volumes_.begin(), volumes_.end(), 0.0);
173 throw std::runtime_error(
"Error: cells volumes not yet computed");
183 if (not cached_volumes) {
185 bool const use_masks = masks !=
nullptr;
189 auto const& i = index_[s];
190 auto const&
volume = mesh_.cell_volume(s);
194 cached_volumes =
true;
206 bool cached_volumes =
false;
211 std::vector<int> cells_ = {};
212 std::vector<double> volumes_ = {};
213 std::map<int, int> index_ = {};
214 std::unordered_set<int> lookup_ = {};
230 class SourceMesh,
class SourceState,
231 class TargetMesh = SourceMesh,
232 class TargetState = SourceState
236 using entity_weights_t = std::vector<Wonton::Weights_t>;
258 SourceMesh
const& source_mesh, SourceState& source_state,
259 TargetMesh
const& target_mesh, TargetState& target_state,
260 std::vector<int>
const& source_entities,
261 std::vector<int>
const& target_entities,
262 Wonton::Executor_type
const* executor
263 ) : source_(source_mesh, source_state, source_entities),
264 target_(target_mesh, target_state, target_entities)
266 #ifdef PORTAGE_ENABLE_MPI 267 auto mpiexecutor =
dynamic_cast<Wonton::MPIExecutor_type
const *
>(executor);
268 if (mpiexecutor && mpiexecutor->mpicomm != MPI_COMM_NULL) {
270 mycomm_ = mpiexecutor->mpicomm;
271 MPI_Comm_rank(mycomm_, &rank_);
272 MPI_Comm_size(mycomm_, &nprocs_);
281 int nb_masks = source_.mesh().num_owned_cells();
283 source_entities_masks_.resize(nb_masks, 1);
284 #ifdef PORTAGE_ENABLE_MPI 286 get_unique_entity_masks<Entity_kind::CELL, SourceMesh>(
287 source_.mesh(), &source_entities_masks_, mycomm_
292 intersection_volumes_.resize(target_.size());
334 double compute_intersect_volumes
337 auto const& target_entities = target_.
cells();
341 auto const& i = target_.index(t);
343 entity_weights_t
const moments = source_weights[t];
344 intersection_volumes_[i] = 0.;
345 for (
auto const& current : moments) {
347 if (source_.contains(current.entityID))
348 intersection_volumes_[i] += current.weights[0];
350 std::printf(
"\tmoments[target:%d][source:%d]: %f\n" 351 , t, current.entityID, current.weights[0]);
354 #if DEBUG_PART_BY_PART 355 std::printf(
"intersect_volume[%02d]: %.3f\n", t, intersection_volumes_[i]);
360 return std::accumulate(intersection_volumes_.begin(), intersection_volumes_.end(), 0.);
389 double source_volume = source_.compute_entity_volumes(source_entities_masks_.data());
390 double target_volume = target_.compute_entity_volumes();
391 double intersect_volume = compute_intersect_volumes(source_weights);
393 global_source_volume_ = source_volume;
394 global_target_volume_ = target_volume;
395 global_intersect_volume_ = intersect_volume;
397 #ifdef PORTAGE_ENABLE_MPI 399 MPI_Allreduce(&source_volume, &global_source_volume_, 1, MPI_DOUBLE, MPI_SUM, mycomm_);
400 MPI_Allreduce(&target_volume, &global_target_volume_, 1, MPI_DOUBLE, MPI_SUM, mycomm_);
401 MPI_Allreduce(&intersect_volume, &global_intersect_volume_, 1, MPI_DOUBLE, MPI_SUM, mycomm_);
404 std::printf(
"source volume: %.3f\n", global_source_volume_);
405 std::printf(
"target volume: %.3f\n", global_target_volume_);
406 std::printf(
"intersect volume: %.3f\n", global_intersect_volume_);
410 std::printf(
"source volume: %.3f\n", source_volume);
411 std::printf(
"target volume: %.3f\n", target_volume);
412 std::printf(
"intersect volume: %.3f\n", intersect_volume);
424 double const relative_voldiff_source =
425 std::abs(global_intersect_volume_ - global_source_volume_)
426 / global_source_volume_;
428 double const relative_voldiff_target =
429 std::abs(global_intersect_volume_ - global_target_volume_)
430 / global_target_volume_;
432 if (relative_voldiff_source > tolerance_) {
433 has_mismatch_ =
true;
436 std::fprintf(stderr,
"\n** MESH MISMATCH - some source cells ");
437 std::fprintf(stderr,
"are not fully covered by the target mesh\n");
441 if (relative_voldiff_target > tolerance_) {
442 has_mismatch_ =
true;
445 std::fprintf(stderr,
"\n** MESH MISMATCH - some target cells ");
446 std::fprintf(stderr,
"are not fully covered by the source mesh\n");
450 if (not has_mismatch_) {
451 is_mismatch_tested_ =
true;
457 relative_voldiff_ = relative_voldiff_source + relative_voldiff_target;
464 std::vector<int> empty_entities;
465 int const target_part_size = target_.size();
467 empty_entities.reserve(target_part_size);
468 is_cell_empty_.resize(target_part_size,
false);
470 for (
auto&& entity : target_.cells()) {
471 auto const& i = target_.index(entity);
472 if (std::abs(intersection_volumes_[i]) < epsilon_) {
473 empty_entities.emplace_back(entity);
474 is_cell_empty_[i] =
true;
478 int nb_empty = empty_entities.size();
479 int global_nb_empty = nb_empty;
481 #ifdef PORTAGE_ENABLE_MPI 484 MPI_Reduce(&nb_empty, &global_nb_empty, 1, MPI_INT, MPI_SUM, 0, mycomm_);
488 if (global_nb_empty > 0 and rank_ == 0) {
490 "One or more target cells are not covered by ANY source cells.\n" 491 "Will assign values based on their neighborhood\n" 496 layer_num_.resize(target_.size(), 0);
500 int old_nb_tagged = -1;
502 while (nb_tagged < nb_empty and nb_tagged > old_nb_tagged) {
503 old_nb_tagged = nb_tagged;
505 std::vector<int> current_layer_entities;
507 for (
auto&& entity : empty_entities) {
508 auto const& i = target_.index(entity);
510 if (layer_num_[i] == 0) {
512 auto neighbors = target_.get_neighbors(entity);
514 for (
auto&& neigh : neighbors) {
515 auto const& j = target_.index(neigh);
519 if (not is_cell_empty_[j] or layer_num_[j] != 0) {
520 current_layer_entities.push_back(entity);
528 for (
auto&& entity : current_layer_entities) {
529 auto const& t = target_.index(entity);
530 layer_num_[t] = nb_layers + 1;
532 nb_tagged += current_layer_entities.size();
534 empty_layers_.emplace_back(current_layer_entities);
539 is_mismatch_tested_ =
true;
540 return has_mismatch_;
582 std::string trg_var_name,
583 double global_lower_bound = -infinity_,
584 double global_upper_bound = infinity_,
585 double conservation_tol = tolerance_,
590 if (source_.state().field_type(Entity_kind::CELL, src_var_name) == Field_type::MESH_FIELD) {
591 return fix_mismatch_meshvar(src_var_name, trg_var_name,
592 global_lower_bound, global_upper_bound,
593 conservation_tol, maxiter,
594 partial_fixup_type, empty_fixup_type);
613 std::string
const & trg_var_name,
614 double global_lower_bound,
615 double global_upper_bound,
616 double conservation_tol,
622 static bool hit_lower_bound =
false;
623 static bool hit_higher_bound =
false;
626 auto const& source_entities = source_.cells();
627 auto const& target_entities = target_.cells();
628 auto const& source_state = source_.state();
629 auto& target_state =
const_cast<TargetState&
>(target_.state());
633 double const* source_data;
634 double* target_data =
nullptr;
635 source_state.mesh_get_data(Entity_kind::CELL, src_var_name, &source_data);
636 target_state.mesh_get_data(Entity_kind::CELL, trg_var_name, &target_data);
651 for (
auto&& entity : target_entities) {
652 auto const& t = target_.index(entity);
653 if (not is_cell_empty_[t]) {
655 #if DEBUG_PART_BY_PART 656 std::printf(
"fixing target_data[%d] with locally conservative fixup\n", entity);
657 std::printf(
"= before: %.3f", target_data[entity]);
660 auto const relative_voldiff =
661 std::abs(intersection_volumes_[t] - target_.volume(t)) / target_.volume(t);
663 if (relative_voldiff > tolerance_) {
664 target_data[entity] *= intersection_volumes_[t] / target_.volume(t);
666 #if DEBUG_PART_BY_PART 667 std::printf(
", after: %.3f\n", target_data[entity]);
683 int current_layer_number = 1;
685 for (
auto const& current_layer : empty_layers_) {
686 for (
auto&& entity : current_layer) {
688 double averaged_value = 0.;
691 target_.template get_neighbors<Entity_type::PARALLEL_OWNED>(entity);
693 for (
auto&& neigh : neighbors) {
694 auto const& i = target_.index(neigh);
695 if (layer_num_[i] < current_layer_number) {
696 averaged_value += target_data[neigh];
700 if (nb_extrapol > 0) {
701 averaged_value /= nb_extrapol;
703 #if DEBUG_PART_BY_PART 706 "No owned neighbors of empty entity to extrapolate data from\n" 710 target_data[entity] = averaged_value;
712 current_layer_number++;
727 double source_sum = 0.;
728 double target_sum = 0.;
730 for (
auto&& s : source_entities) {
731 auto const& i = source_.index(s);
732 source_sum += source_data[s] * source_.volume(i);
735 for (
auto&& t : target_entities) {
736 auto const& i = target_.index(t);
737 target_sum += target_data[t] * target_.volume(i);
740 double global_source_sum = source_sum;
741 double global_target_sum = target_sum;
743 #ifdef PORTAGE_ENABLE_MPI 746 &source_sum, &global_source_sum, 1, MPI_DOUBLE, MPI_SUM, mycomm_
750 &target_sum, &global_target_sum, 1, MPI_DOUBLE, MPI_SUM, mycomm_
755 double absolute_diff = global_target_sum - global_source_sum;
756 double relative_diff = absolute_diff / global_source_sum;
758 if (std::abs(relative_diff) < conservation_tol) {
771 double adj_target_volume;
772 double global_adj_target_volume;
773 double global_covered_target_volume;
776 double covered_target_volume = 0.;
777 for (
auto&& entity : target_entities) {
778 auto const& t = target_.index(entity);
779 if (not is_cell_empty_[t]) {
780 covered_target_volume += target_.volume(t);
783 global_covered_target_volume = covered_target_volume;
784 #ifdef PORTAGE_ENABLE_MPI 787 &covered_target_volume, &global_covered_target_volume,
788 1, MPI_DOUBLE, MPI_SUM, mycomm_
792 adj_target_volume = covered_target_volume;
793 global_adj_target_volume = global_covered_target_volume;
796 adj_target_volume = target_.total_volume();
797 global_adj_target_volume = global_target_volume_;
801 double udiff = absolute_diff / global_adj_target_volume;
804 while (std::abs(relative_diff) > conservation_tol and iter < maxiter) {
806 for (
auto&& entity : target_entities) {
807 auto const& t = target_.index(entity);
808 bool is_owned = target_.mesh().cell_get_type(entity) == Entity_type::PARALLEL_OWNED;
809 bool should_fix = (empty_fixup_type !=
LEAVE_EMPTY or not is_cell_empty_[t]);
811 if (is_owned and should_fix) {
812 if ((target_data[entity] - udiff) < global_lower_bound) {
816 target_data[entity] = global_lower_bound;
818 if (not hit_lower_bound) {
820 "Hit lower bound for cell %d (and maybe other ones) on rank %d\n",
823 hit_lower_bound =
true;
827 adj_target_volume -= target_.volume(t);
829 }
else if ((target_data[entity] - udiff) > global_upper_bound) {
833 target_data[entity] = global_upper_bound;
835 if (not hit_higher_bound) {
837 "Hit upper bound for cell %d (and maybe other ones) on rank %d\n",
840 hit_higher_bound =
true;
845 adj_target_volume -= target_.volume(t);
851 target_data[entity] -= udiff;
858 for (
auto&& entity : target_entities) {
859 auto const& t = target_.index(entity);
860 target_sum += target_.volume(t) * target_data[entity];
863 global_target_sum = target_sum;
864 #ifdef PORTAGE_ENABLE_MPI 867 &target_sum, &global_target_sum, 1, MPI_DOUBLE, MPI_SUM, mycomm_
878 absolute_diff = global_target_sum - global_source_sum;
879 global_adj_target_volume = adj_target_volume;
881 #ifdef PORTAGE_ENABLE_MPI 884 &adj_target_volume, &global_adj_target_volume,
885 1, MPI_DOUBLE, MPI_SUM, mycomm_
890 udiff = absolute_diff / global_adj_target_volume;
891 relative_diff = absolute_diff / global_source_sum;
895 global_adj_target_volume = (
896 empty_fixup_type ==
LEAVE_EMPTY ? global_covered_target_volume
897 : global_target_volume_
903 if (std::abs(relative_diff) > conservation_tol) {
906 "Redistribution not entirely successfully for variable %s\n" 907 "Relative conservation error is %f\n" 908 "Absolute conservation error is %f\n",
909 src_var_name.data(), relative_diff, absolute_diff
917 std::fprintf(stderr,
"Unknown Partial fixup type\n");
928 bool is_mismatch_tested_ =
false;
929 bool has_mismatch_ =
false;
932 static constexpr
double infinity_ = std::numeric_limits<double>::max();
934 static constexpr
double tolerance_ = 1.E2 * epsilon_;
936 std::vector<int> source_entities_masks_ = {};
937 std::vector<double> intersection_volumes_ = {};
940 double global_source_volume_ = 0.;
941 double global_target_volume_ = 0.;
942 double global_intersect_volume_ = 0.;
943 double relative_voldiff_ = 0.;
946 std::vector<int> layer_num_ = {};
947 std::vector<bool> is_cell_empty_ = {};
948 std::vector<std::vector<int>> empty_layers_ = {};
953 bool distributed_ =
false;
954 #ifdef PORTAGE_ENABLE_MPI 955 MPI_Comm mycomm_ = MPI_COMM_NULL;
960 #endif //PORTAGE_PARTS_H bool check_mismatch(Portage::vector< entity_weights_t > const &source_weights)
Check for a source/target boundary mismatch (the fix itself is applied by fix_mismatch).
Definition: parts.h:383
bool contains(int id) const
Check if a given entity is in the part.
Definition: parts.h:111
State & state()
Get a normal reference to the underlying state.
Definition: parts.h:96
std::vector< int > const & cells() const
Get a reference to the cell list.
Definition: parts.h:103
PartPair(SourceMesh const &source_mesh, SourceState &source_state, TargetMesh const &target_mesh, TargetState &target_state, std::vector< int > const &source_entities, std::vector< int > const &target_entities, Wonton::Executor_type const *executor)
Construct a source-target mesh parts pair.
Definition: parts.h:257
std::vector< T > vector
Definition: portage.h:238
Part(Mesh const &mesh, State &state, std::vector< int > const &cells)
Construct a mesh part object.
Definition: parts.h:53
const int & index(int id) const
Retrieve relative index of given entity.
Definition: parts.h:119
std::vector< int > get_neighbors(int entity) const
Retrieve the neighbors of the given entity on mesh part.
Definition: parts.h:150
bool fix_mismatch_meshvar(std::string const &src_var_name, std::string const &trg_var_name, double global_lower_bound, double global_upper_bound, double conservation_tol, int maxiter, Partial_fixup_type partial_fixup_type, Empty_fixup_type empty_fixup_type) const
Repair a remapped mesh field to account for boundary mismatch.
Definition: parts.h:612
void for_each(InputIterator first, InputIterator last, UnaryFunction f)
Definition: portage.h:264
Definition: portage.h:114
bool has_mismatch() const
Do source and target meshes have a boundary mismatch?
Definition: parts.h:305
Partial_fixup_type
Fixup options for partially filled cells.
Definition: portage.h:114
Definition: portage.h:134
double const epsilon
Numerical tolerance.
Definition: weight.h:34
bool mismatch_tested() const
Is mismatch already tested?
Definition: parts.h:312
Manages source and target sub-meshes for part-by-part remap. It detects boundaries mismatch and provi...
Definition: parts.h:234
const SourcePart & source() const
Retrieve a reference to the source mesh part.
Definition: parts.h:319
std::vector< Wonton::Weights_t > entity_weights_t
Definition: parts.h:36
const State & state() const
Get a constant reference to the underlying state.
Definition: parts.h:89
Definition: portage.h:134
const Mesh & mesh() const
Get a constant reference to the underlying mesh.
Definition: parts.h:82
#define DEBUG_PART_BY_PART
Definition: parts.h:25
Definition: portage.h:114
Part()=default
Default constructor.
const int & size() const
Get part size.
Definition: parts.h:126
Definition: coredriver.h:42
bool fix_mismatch(std::string src_var_name, std::string trg_var_name, double global_lower_bound=-infinity_, double global_upper_bound=infinity_, double conservation_tol=tolerance_, int maxiter=5, Partial_fixup_type partial_fixup_type=SHIFTED_CONSERVATIVE, Empty_fixup_type empty_fixup_type=EXTRAPOLATE) const
Repair the remapped field to account for boundary mismatch.
Definition: parts.h:581
const TargetPart & target() const
Retrieve a reference to the target mesh part.
Definition: parts.h:326
const double & volume(int id) const
Get the volume of the entity at the given part-relative index.
Definition: parts.h:134
Definition: portage.h:114
double total_volume() const
Compute the total volume of the part.
Definition: parts.h:169
double compute_entity_volumes(const int *masks=nullptr)
Compute and store volumes of each cell of the part.
Definition: parts.h:182
Empty_fixup_type
Fixup options for empty cells.
Definition: portage.h:134