swarm_state.h
Go to the documentation of this file.
1 /*
2 This file is part of the Ristra portage project.
3 Please see the license file at the root of this repository, or at:
4  https://github.com/laristra/portage/blob/master/LICENSE
5 */
6 #ifndef SWARM_STATE_H_INC_
7 #define SWARM_STATE_H_INC_
8 
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "portage/swarm/swarm.h"
17 
18 namespace Portage { namespace Meshfree {
19 
27 template<int dim>
28 class SwarmState {
29 public:
34  SwarmState() = default;
35 
41  explicit SwarmState(Swarm<dim> const& swarm)
42  : num_local_points_(swarm.num_owned_particles())
43  {}
44 
50  explicit SwarmState(int size) : num_local_points_(size) {}
51 
62  template<typename State>
63  SwarmState(State const& state, Wonton::Entity_kind kind) {
64  // create return value
65  int num_entities = 0;
66  auto const field_names = state.names();
67 
68  // retrieve number of entities
69  for (auto&& name : field_names) {
70  if (state.get_entity(name) == kind) {
71  num_entities = state.get_data_size(kind, name);
72  break;
73  }
74  }
75  // set number of particles
76  num_local_points_ = num_entities;
77 
78  // copy data
79  for (auto&& name : field_names) {
80  if (state.get_entity(name) == kind) {
81  // retrieve field from mesh state
82  const double* values;
83  state.mesh_get_data(kind, name, &values);
84  assert(values != nullptr);
85 
86  // perform a deep copy then
87  auto& field = fields_dbl_[name];
88  field.resize(num_entities);
89  std::copy(values, values + num_entities, field.begin());
90  }
91  }
92  }
93 
108  template<typename State>
109  SwarmState(std::vector<State*> const& states, Wonton::Entity_kind kind) {
110 
111  auto const& names = states[0]->names();
112  int const num_fields = names.size();
113  int const num_states = states.size();
114 
115  // check all fields in all states
116  for (auto&& state : states) {
117  auto current = state->names();
118  if (current.size() == unsigned(num_fields)) {
119  for (int i = 0; i < num_fields; ++i)
120  if (names[i] != current[i])
121  throw std::runtime_error("field names do not match");
122  } else
123  throw std::runtime_error("field names do not match");
124  }
125 
126  // get sizes of data fields that match entity on each wrapper
127  std::vector<int> sizes[num_fields];
128 
129  for (int i = 0; i < num_fields; ++i) {
130  auto const& name = names[i];
131  for (int j = 0; j < num_states; ++j) {
132  auto const& state = *(states[j]);
133  if (state.get_entity(name) == kind) {
134  sizes[i].emplace_back(state.get_data_size(kind, name));
135  }
136  }
137  }
138 
139  // ensure sizes are same across names for each wrapper
140  for (int i = 0; i < num_fields; ++i) {
141  for (int j = 0; j < num_states; ++j) {
142  assert(sizes[0][j] > 0);
143  if (sizes[i][j] != sizes[0][j])
144  throw std::runtime_error("field sizes do not match");
145  }
146  }
147 
148  // compute particle offsets per state
149  int num_entities = 0;
150  int offset[num_states];
151  for (int i = 0; i < num_states; i++) {
152  offset[i] = num_entities;
153  num_entities += sizes[0][i];
154  }
155 
156  num_local_points_ = num_entities;
157 
158  // copy data
159  for (auto&& name : names) {
160  // resize particle field array
161  auto& field = fields_dbl_[name];
162  field.resize(num_entities);
163 
164  for (int i = 0; i < num_states; ++i) {
165  // retrieve field from mesh state
166  double* values = nullptr;
167  states[i]->mesh_get_data(kind, name, &values);
168  assert(values != nullptr);
169  // perform a deep copy then
170  std::copy(values, values + sizes[0][i], field.begin() + offset[i]);
171  }
172  }
173  }
174 
179  ~SwarmState() = default;
180 
188  template<typename T = double>
189  void add_field(std::string name, Portage::vector<T> const& value) {
190 
191  static_assert(std::is_arithmetic<T>::value, "only numeric fields");
192  // sizes should match
193  assert(value.size() == unsigned(num_local_points_));
194 
195  if (std::is_integral<T>::value) {
196  auto& field = fields_int_[name];
197  field.resize(value.size());
198  std::copy(value.begin(), value.end(), field.begin());
199  } else {
200  auto& field = fields_dbl_[name];
201  field.resize(value.size());
202  std::copy(value.begin(), value.end(), field.begin());
203  }
204  }
205 
206 #ifdef PORTAGE_ENABLE_THRUST
207 
214  template<typename T = double>
215  void add_field(std::string name, std::vector<T> const& value) {
216 
217  static_assert(std::is_arithmetic<T>::value, "only numeric fields");
218  // sizes should match
219  assert(value.size() == unsigned(num_local_points_));
220 
221  if (std::is_integral<T>::value) {
222  auto& field = fields_int_[name];
223  field.resize(value.size());
224  std::copy(value.begin(), value.end(), field.begin());
225  } else {
226  auto& field = fields_dbl_[name];
227  field.resize(value.size());
228  std::copy(value.begin(), value.end(), field.begin());
229  }
230  }
231 #endif
232 
240  template<typename T = double>
241  void add_field(std::string name, const T* const value) {
242 
243  static_assert(std::is_arithmetic<T>::value, "only numeric fields");
244  // assume value is of the right length - we can't check it
245  assert(value != nullptr);
246  // do a deep copy on related array
247  if (std::is_integral<T>::value) {
248  auto& field = fields_int_[name];
249  field.resize(num_local_points_);
250  std::copy(value, value + num_local_points_, field.begin());
251  } else {
252  auto& field = fields_dbl_[name];
253  field.resize(num_local_points_);
254  std::copy(value, value + num_local_points_, field.begin());
255  }
256  }
257 
265  template<typename T=double>
266  void add_field(std::string name, T value) {
267 
268  static_assert(std::is_arithmetic<T>::value, "only numeric fields");
269 
270  if (std::is_integral<T>::value) {
271  auto& field = fields_int_[name];
272  field.resize(num_local_points_, value);
273  } else {
274  auto& field = fields_dbl_[name];
275  field.resize(num_local_points_, value);
276  }
277  }
278 
289  Portage::vector<int>& get_field_int(std::string name) const {
290  assert(fields_int_.count(name));
291  using T = Portage::vector<int>;
292  return const_cast<T&>(fields_int_.at(name));
293  }
294 
305  Portage::vector<double>& get_field_dbl(std::string name) const {
306  assert(fields_dbl_.count(name));
307  using T = Portage::vector<double>;
308  return const_cast<T&>(fields_dbl_.at(name));
309  }
310 
316  Portage::vector<double>& get_field(std::string name) const {
317  return get_field_dbl(name);
318  }
319 
326  template<typename T = double>
327  void copy_field(std::string name, T* value) {
328 
329  static_assert(std::is_arithmetic<T>::value, "only numeric fields");
330  assert(value != nullptr);
331 
332  if (std::is_integral<T>::value) {
333  assert(fields_int_.count(name));
334  auto& field = fields_int_[name];
335  std::copy(field.begin(), field.end(), value);
336  } else {
337  assert(fields_dbl_.count(name));
338  auto& field = fields_dbl_[name];
339  std::copy(field.begin(), field.end(), value);
340  }
341  }
342 
348  int get_size() { return num_local_points_; }
349 
356  template<typename T = double>
357  std::vector<std::string> get_field_names() {
358 
359  static_assert(std::is_arithmetic<T>::value, "only numeric fields");
360 
361  std::vector<std::string> list;
362  if (std::is_integral<T>::value) {
363  for (auto&& field : fields_int_)
364  list.emplace_back(field.first);
365  } else {
366  for (auto&& field : fields_dbl_)
367  list.emplace_back(field.first);
368  }
369 
370  return list;
371  }
372 
380  template<typename T>
381  void extend_field(std::string name, Portage::vector<T> const& values) {
382 
383  static_assert(std::is_arithmetic<T>::value, "only numeric fields");
384 
385  if (std::is_integral<T>::value) {
386  assert(fields_int_.count(name));
387  auto& field = fields_int_[name];
388  field.insert(field.begin(), values.begin(), values.end());
389  } else {
390  assert(fields_dbl_.count(name));
391  auto& field = fields_dbl_[name];
392  field.insert(field.end(), values.begin(), values.end());
393  }
394  }
395 
396  private:
398  int num_local_points_ = 0;
399 
401  std::map<std::string, Portage::vector<int>> fields_int_ {};
402  std::map<std::string, Portage::vector<double>> fields_dbl_ {};
403 };
404 
405 }} //namespace Portage::MeshFree
406 
407 #endif // SWARM_STATE_H_INC_
void copy_field(std::string name, T *value)
Retrieve the specified field.
Definition: swarm_state.h:327
std::vector< T > vector
Definition: portage.h:238
int get_size()
Get number of particles.
Definition: swarm_state.h:348
void add_field(std::string name, const T *const value)
Set a field on the swarm.
Definition: swarm_state.h:241
Particle field state class.
Definition: swarm_state.h:28
SwarmState(int size)
Initialize with a field size.
Definition: swarm_state.h:50
SwarmState(Swarm< dim > const &swarm)
Initialize from a reference swarm.
Definition: swarm_state.h:41
~SwarmState()=default
Destructor.
Portage::vector< double > & get_field(std::string name) const
Definition: swarm_state.h:316
Definition: coredriver.h:42
SwarmState()=default
Create an empty state.
An effective "mesh" class for a collection of disconnected points (particles).
Definition: swarm.h:35
std::vector< std::string > get_field_names()
Retrieve the list of field names.
Definition: swarm_state.h:357
void extend_field(std::string name, Portage::vector< T > const &values)
Extend the field with the given subfield.
Definition: swarm_state.h:381
Portage::vector< int > & get_field_int(std::string name) const
Retrieve a specified integer field.
Definition: swarm_state.h:289
void add_field(std::string name, T value)
Set an empty field on the swarm.
Definition: swarm_state.h:266
Portage::vector< double > & get_field_dbl(std::string name) const
Retrieve a specified real field.
Definition: swarm_state.h:305
SwarmState(std::vector< State *> const &states, Wonton::Entity_kind kind)
Create the swarm state from set of mesh states wrappers.
Definition: swarm_state.h:109
void add_field(std::string name, Portage::vector< T > const &value)
Set a field on the swarm.
Definition: swarm_state.h:189
SwarmState(State const &state, Wonton::Entity_kind kind)
Create the swarm state from a given mesh state wrapper.
Definition: swarm_state.h:63