42 template <
typename XType,
typename YType,
typename WType>
//! return true if the dataset is empty (i.e., it holds no feature vectors)
61 inline bool empty()
const {
return features.empty(); }
//! return the number of samples in the dataset (the number of stored
//! feature vectors, narrowed from the container's size type to int)
63 inline int size()
const {
return (
int)features.size(); }
//! return true if the dataset contains weighted samples (i.e., the
//! optional weights container is non-empty)
65 inline bool hasWeights()
const {
return !weights.empty(); }
//! return true if the dataset has external indices associated with each
//! sample (i.e., the optional indexes container is non-empty)
67 inline bool hasIndexes()
const {
return !indexes.empty(); }
//! return true if the dataset is valid (e.g., number of targets equals
//! number of feature vectors); declared here, defined out of line
70 inline bool valid()
const;
//! returns the number of samples with a given target label
73 int count(
const YType& label)
const;
//! pre-allocate memory for storing samples (feature vectors and targets)
//! \param reserveSize  amount to reserve -- presumably a number of samples;
//!                     TODO confirm against the definition
75 void reserve(
int reserveSize);
//! returns the number of features in the feature vector
77 inline int numFeatures()
const;
//! returns the minimum target value in the dataset
//! NOTE(review): the out-of-line body appears to return (YType)0 when
//! there are no targets -- confirm
79 inline YType minTarget()
const;
//! returns the maximum target value in the dataset
//! NOTE(review): the out-of-line body appears to return (YType)0 when
//! there are no targets -- confirm
81 inline YType maxTarget()
const;
//! static query on a dataset file (no dataset instance required)
//! NOTE(review): presumably returns the number of samples stored in
//! \p filename -- confirm against the out-of-line definition
85 static int size(
const char *filename);
//! static query on a dataset file (no dataset instance required)
//! NOTE(review): presumably returns the feature-vector length recorded in
//! the header of \p filename -- confirm against the out-of-line definition
87 static int numFeatures(
const char *filename);
//! static query on a dataset file (no dataset instance required)
//! NOTE(review): presumably reports whether the samples stored in
//! \p filename carry weights -- confirm against the out-of-line definition
89 static bool hasWeights(
const char *filename);
//! static query on a dataset file (no dataset instance required)
//! NOTE(review): presumably reports whether the samples stored in
//! \p filename carry external indices -- confirm against the out-of-line
//! definition
91 static bool hasIndexes(
const char *filename);
//! writes the current dataset to disk (optionally appending to an existing
//! dataset)
//! \param filename  path of the dataset file
//! \param bAppend   if true, add to an existing file rather than starting a
//!                  new one
//! NOTE(review): the return value looks like a record count (0 for an empty
//! dataset) -- confirm against the out-of-line definition
97 int write(
const char *filename,
bool bAppend =
false)
const;
//! writes a sub-range of the current dataset to disk (optionally appending
//! to an existing dataset)
//! \param filename   path of the dataset file
//! \param startIndx  position of the first sample to write
//! \param endIndx    position of the last sample to write -- the range
//!                   appears to be inclusive at both ends; TODO confirm
//! \param bAppend    if true, add to an existing file rather than starting
//!                   a new one
99 int write(
const char *filename,
int startIndx,
int endIndx,
bool bAppend =
false)
const;
//! reads a dataset from disk (optionally appending to the current dataset)
//! \param filename  path of the dataset file
//! \param bAppend   if true, keep the samples already held and add the ones
//!                  read from the file
101 int read(
const char *filename,
bool bAppend =
false);
//! reads a sub-range of records from a dataset file (optionally appending
//! to the current dataset)
//! \param filename   path of the dataset file
//! \param startIndx  position of the first record to read
//! \param endIndx    position of the last record to read -- the range
//!                   appears to be inclusive at both ends; TODO confirm
//! \param bAppend    if true, keep the samples already held and add the
//!                   ones read from the file
103 int read(
const char *filename,
int startIndx,
int endIndx,
bool bAppend =
false);
//! NOTE(review): presumably appends a single sample (feature vector \p x
//! with target \p y) to the dataset; the definition is not visible here, so
//! the meaning of the return value is unconfirmed
109 int append(
const vector<XType>& x,
const YType& y);
//! NOTE(review): presumably appends a single weighted sample (feature
//! vector \p x, target \p y, weight \p w) to the dataset; the definition is
//! not visible here, so the meaning of the return value is unconfirmed
111 int append(
const vector<XType>& x,
const YType& y,
const WType& w);
//! NOTE(review): presumably appends a single weighted sample (feature
//! vector \p x, target \p y, weight \p w) together with its external index
//! \p indx; the definition is not visible here, so the meaning of the
//! return value is unconfirmed
113 int append(
const vector<XType>& x,
const YType& y,
const WType& w,
int indx);
//! subsample a dataset (balanced is only valid for discrete target types)
//! \param sampleRate  subsampling factor
//! \param bBalanced   request balanced sampling across target values
//! NOTE(review): the out-of-line body appears to require sampleRate > 0 and
//! to return the number of samples kept -- confirm
118 int subSample(
int sampleRate,
bool bBalanced =
false);
130 template <
typename XType,
typename YType,
typename WType>
// NOTE(review): the signature line of this definition is missing from the
// listing. The surviving line is a member-initializer list that copies all
// four containers from `d`, so this appears to be the copy constructor --
// confirm against the original header.
136 template <
typename XType,
typename YType,
typename WType>
138 features(d.features), targets(d.targets), weights(d.weights), indexes(d.indexes) {
142 template <
typename XType,
typename YType,
typename WType>
148 template <
typename XType,
typename YType,
typename WType>
// Fragment of an out-of-line definition; by listing position this appears to
// be valid(). Its signature, the rest of the loop header and the closing
// lines are missing from the listing.
154 template <
typename XType,
typename YType,
typename WType>
// Walk the feature vectors from the first one onwards.
158 for (
typename vector<vector<XType> >::const_iterator it =
features.begin();
// A feature vector whose length differs from nFeatures makes the dataset
// invalid (nFeatures is set on a line not present here).
160 if (it->size() != nFeatures)
return false;
// Fragment of an out-of-line definition; by listing position this appears to
// be count(label). Only the start of the loop over the targets survives; the
// comparison and the counter are on lines not present in the listing.
168 template <
typename XType,
typename YType,
typename WType>
172 for (
typename vector<YType>::const_iterator it =
targets.begin();
181 template <
typename XType,
typename YType,
typename WType>
190 template <
typename XType,
typename YType,
typename WType>
// Fragment of an out-of-line definition; by listing position this appears to
// be minTarget(). The signature line is missing from the listing.
196 template <
typename XType,
typename YType,
typename WType>
// With no targets the result is (YType)0; the expression for the non-empty
// case continues on a line not present in the listing.
199 return targets.empty() ? (YType)0 :
// Fragment of an out-of-line definition; by listing position this appears to
// be maxTarget(). The signature line is missing from the listing.
203 template <
typename XType,
typename YType,
typename WType>
// With no targets the result is (YType)0; the expression for the non-empty
// case continues on a line not present in the listing.
206 return targets.empty() ? (YType)0 :
// Fragment of an out-of-line definition; by listing position and content this
// appears to be the static size(const char *filename) overload, which derives
// a record count from the length of a dataset file. The signature, the local
// declarations of `flags` / `nFeatures`, the conditions guarding the log
// calls and all closing braces are missing from the listing.
211 template <
typename XType,
typename YType,
typename WType>
214 DRWN_ASSERT(filename != NULL);
// Open the file for binary reading.
217 ifstream ifs(filename, ifstream::in | ifstream::binary);
// Error / warning reports; the tests that trigger them are not visible.
219 DRWN_LOG_ERROR(
"could not find file " << filename);
224 DRWN_LOG_WARNING(
"empty file " << filename);
// Header word 1: flags. The upper 16 bits must equal 0x0001, which the
// assertion message describes as the file version.
229 ifs.read((
char *)&flags,
sizeof(
unsigned));
230 DRWN_ASSERT_MSG((flags & 0xffff0000) == 0x00010000,
"unrecognized file version");
// Header word 2: number of features per sample.
233 ifs.read((
char *)&nFeatures,
sizeof(
int));
// Payload length = file length minus the two header words.
// NOTE(review): tellg() is narrowed to int and then mixed with a size_t
// operand, so the subtraction is carried out in unsigned arithmetic before
// being converted back to int; files larger than INT_MAX bytes would not be
// represented correctly -- confirm whether that matters for callers.
235 ifs.seekg(0, ios::end);
236 int len = (int)ifs.tellg() - 2 *
sizeof(int);
// Size of one record: the target, nFeatures feature values, plus a weight
// when flag bit 0 is set and an int index when flag bit 1 is set.
240 int bytesPerRecord =
sizeof(YType) + nFeatures *
sizeof(XType);
241 if ((flags & 0x00000001) == 0x00000001) bytesPerRecord +=
sizeof(WType);
242 if ((flags & 0x00000002) == 0x00000002) bytesPerRecord +=
sizeof(
int);
// The payload must be a whole number of records.
244 DRWN_ASSERT_MSG(len % bytesPerRecord == 0,
"corrupt file " << filename
245 <<
" (len: " << len <<
", bytes/record = " << bytesPerRecord <<
")");
// Result: number of records in the file.
246 return (
int)(len / bytesPerRecord);
// Fragment of an out-of-line definition; by listing position this appears to
// be the static numFeatures(const char *filename) overload. The signature,
// the local declarations, the conditions guarding the log calls, the return
// statement and the closing braces are missing from the listing.
249 template <
typename XType,
typename YType,
typename WType>
252 DRWN_ASSERT(filename != NULL);
// Open the file for binary reading.
255 ifstream ifs(filename, ifstream::in | ifstream::binary);
// Error / warning reports; the tests that trigger them are not visible.
257 DRWN_LOG_ERROR(
"could not open file " << filename);
262 DRWN_LOG_WARNING(
"empty file " << filename);
// Header word 1: flags, whose upper 16 bits must equal 0x0001.
267 ifs.read((
char *)&flags,
sizeof(
unsigned));
268 DRWN_ASSERT_MSG((flags & 0xffff0000) == 0x00010000,
"unrecognized file version");
// Header word 2: number of features per sample (presumably the value this
// function returns -- the return statement is not visible).
271 ifs.read((
char *)&nFeatures,
sizeof(
int));
// Fragment of an out-of-line definition; by listing position and content this
// appears to be the static hasWeights(const char *filename) overload. The
// signature, the declaration of `flags`, the conditions guarding the log
// calls and the closing braces are missing from the listing.
277 template <
typename XType,
typename YType,
typename WType>
280 DRWN_ASSERT(filename != NULL);
// Open the file for binary reading.
283 ifstream ifs(filename, ifstream::in | ifstream::binary);
// Error / warning reports; the tests that trigger them are not visible.
285 DRWN_LOG_ERROR(
"could not open file " << filename);
290 DRWN_LOG_WARNING(
"empty file " << filename);
// Read the header flags word; its upper 16 bits must equal 0x0001.
295 ifs.read((
char *)&flags,
sizeof(
unsigned));
296 DRWN_ASSERT_MSG((flags & 0xffff0000) == 0x00010000,
"unrecognized file version");
// Flag bit 0 marks a file whose records carry weights.
298 return ((flags & 0x00000001) == 0x00000001);
// Fragment of an out-of-line definition; by listing position and content this
// appears to be the static hasIndexes(const char *filename) overload. The
// signature, the declaration of `flags`, the conditions guarding the log
// calls and the closing braces are missing from the listing.
301 template <
typename XType,
typename YType,
typename WType>
304 DRWN_ASSERT(filename != NULL);
// Open the file for binary reading.
307 ifstream ifs(filename, ifstream::in | ifstream::binary);
// Error / warning reports; the tests that trigger them are not visible.
309 DRWN_LOG_ERROR(
"could not open file " << filename);
314 DRWN_LOG_WARNING(
"empty file " << filename);
// Read the header flags word; its upper 16 bits must equal 0x0001.
319 ifs.read((
char *)&flags,
sizeof(
unsigned));
320 DRWN_ASSERT_MSG((flags & 0xffff0000) == 0x00010000,
"unrecognized file version");
// Flag bit 1 marks a file whose records carry an external int index.
322 return ((flags & 0x00000002) == 0x00000002);
327 template <
typename XType,
typename YType,
typename WType>
// Fragment of the out-of-line definition of write(filename, bAppend); the
// signature line and braces are missing from the listing.
336 template <
typename XType,
typename YType,
typename WType>
// Nothing to write for an empty dataset: report zero and return early.
339 if (this->
empty())
return 0;
// Otherwise delegate to the range overload with the full index range,
// 0 through size() - 1.
340 return write(filename, 0, this->
size() - 1, bAppend);
// Fragment of an out-of-line definition; by listing position and content this
// appears to be write(filename, startIndx, endIndx, bAppend). The signature,
// several local declarations (ofs, fileFlags, fileNumFeatures, nFeatures),
// the branch conditions and the closing braces are missing from the listing.
343 template <
typename XType,
typename YType,
typename WType>
// Preconditions: a filename, a valid dataset, and an index range that lies
// inside [0, size()) with startIndx <= endIndx.
346 DRWN_ASSERT(filename != NULL);
347 DRWN_ASSERT(this->
valid());
348 DRWN_ASSERT_MSG((startIndx >= 0) && (endIndx < this->
size()) && (startIndx <= endIndx),
349 "startIndx = " << startIndx <<
", endIndx = " << endIndx <<
", size() = " << this->
size());
// Header flags word; the upper 16 bits hold 0x0001. Lines that would add
// the weight / index bits are not visible in the listing.
352 unsigned flags = 0x00010000;
// Presumably the append branch (its condition is not visible): open the
// existing file for reading and writing, check that its header flags and
// feature count match this dataset, then position at the end of the file.
362 ofs.open(filename, ios::in | ios::out | ios::binary);
363 ofs.seekg(0, ios::beg);
364 ofs.read((
char *)&fileFlags,
sizeof(
unsigned));
365 DRWN_ASSERT(fileFlags == flags);
366 ofs.read((
char *)&fileNumFeatures,
sizeof(
int));
367 DRWN_ASSERT(fileNumFeatures == nFeatures);
369 ofs.seekp(0, ios::end);
// Presumably the non-append branch: open for output only and write the
// two-word header (flags, then feature count).
371 ofs.open(filename, ios::out | ios::binary);
372 ofs.write((
char *)&flags,
sizeof(
unsigned));
373 ofs.write((
char *)&nFeatures,
sizeof(
int));
375 DRWN_ASSERT_MSG(!ofs.fail(), filename);
// Emit one record per sample; endIndx is included in the range.
378 for (
int i = startIndx; i <= endIndx ; i++) {
379 ofs.write((
char *)&
targets[i],
sizeof(YType));
// The nFeatures values are written as one contiguous run starting at
// element 0. NOTE(review): if nFeatures can be 0 this indexes an empty
// vector -- confirm that zero-feature datasets never reach this point.
380 ofs.write((
char *)&
features[i][0], nFeatures *
sizeof(XType));
// Optional weight and external index; the guarding conditions are on lines
// not present in the listing.
382 ofs.write((
char *)&
weights[i],
sizeof(WType));
385 ofs.write((
char *)&
indexes[i],
sizeof(
int));
// Payload length after writing = stream position minus the two header words.
// NOTE(review): tellp() is narrowed to int and mixed with a size_t operand,
// so the subtraction is done in unsigned arithmetic -- confirm this is
// acceptable for the file sizes in use.
389 int len = (int)ofs.tellp() - 2 *
sizeof(int);
// Size of one record for this dataset: target, features, plus weight and
// index when present.
392 int bytesPerRecord =
sizeof(YType) + nFeatures *
sizeof(XType);
393 if (
hasWeights()) bytesPerRecord +=
sizeof(WType);
394 if (
hasIndexes()) bytesPerRecord +=
sizeof(
int);
// The payload must be a whole number of records.
396 DRWN_ASSERT_MSG(len % bytesPerRecord == 0,
"corrupt file " << filename
397 <<
" (len: " << len <<
", bytes/record = " << bytesPerRecord <<
")");
// Result: the payload length expressed as a number of records.
398 return (
int)(len / bytesPerRecord);
// Fragment of the out-of-line definition of read(filename, bAppend); the
// signature line and braces are missing from the listing.
401 template <
typename XType,
typename YType,
typename WType>
// Delegate to the range overload starting at record 0, with the largest int
// as the end index so that no upper bound is imposed by the caller.
404 return read(filename, 0, numeric_limits<int>::max(), bAppend);
// Fragment of an out-of-line definition; by listing position and content this
// appears to be read(filename, startIndx, endIndx, bAppend). The signature,
// several local declarations (flags, nFeatures, y, w, index), the conditions
// guarding the log calls, the statements that store each record, the return
// statement and the closing braces are missing from the listing.
407 template <
typename XType,
typename YType,
typename WType>
410 DRWN_ASSERT(filename != NULL);
411 DRWN_ASSERT((startIndx >= 0) && (endIndx >= startIndx));
// When not appending, discard the samples currently held.
412 if (!bAppend)
clear();
// Open the file for binary reading.
415 ifstream ifs(filename, ifstream::in | ifstream::binary);
// Error / warning reports; the tests that trigger them are not visible.
417 DRWN_LOG_ERROR(
"could not find file " << filename);
422 DRWN_LOG_WARNING(
"empty file " << filename);
// Header word 1: flags, whose upper 16 bits must equal 0x0001.
427 ifs.read((
char *)&flags,
sizeof(
unsigned));
428 DRWN_ASSERT_MSG((flags & 0xffff0000) == 0x00010000,
"unrecognized file version: " << flags);
// If samples are already held (append case), the file's weight and index
// flag bits must agree with what this dataset stores.
429 DRWN_ASSERT(
empty() || ((flags & 0x00000001) == (
hasWeights() ? 0x00000001 : 0x00000000)));
430 DRWN_ASSERT(
empty() || ((flags & 0x00000002) == (
hasIndexes() ? 0x00000002 : 0x00000000)));
// Header word 2: number of features per sample.
433 ifs.read((
char *)&nFeatures,
sizeof(
int));
// Size of one record: target, nFeatures feature values, plus a weight when
// flag bit 0 is set and an int index when flag bit 1 is set.
436 int bytesPerRecord =
sizeof(YType) + nFeatures *
sizeof(XType);
437 if ((flags & 0x00000001) == 0x00000001) bytesPerRecord +=
sizeof(WType);
438 if ((flags & 0x00000002) == 0x00000002) bytesPerRecord +=
sizeof(
int);
// Skip the first startIndx records, relative to the position just after
// the header.
// NOTE(review): the offset is computed in int arithmetic and could overflow
// for very large startIndx * bytesPerRecord -- confirm expected file sizes.
441 ifs.seekg(startIndx * bytesPerRecord, ios::cur);
// NOTE(review): message reads "less than N record" (singular); the condition
// that triggers it is not visible.
444 DRWN_LOG_WARNING(
"less than " << startIndx <<
" record in file " << filename);
// Scratch feature buffer, filled anew for every record read below.
450 vector<XType> x(nFeatures);
// Read records until endIndx (inclusive) has been consumed or a read fails.
// The line that advances recordCount is not visible in the listing.
454 int recordCount = startIndx;
455 while (recordCount <= endIndx) {
456 ifs.read((
char *)&y,
sizeof(YType));
457 ifs.read((
char *)&x[0], nFeatures *
sizeof(XType));
458 if ((flags & 0x00000001) == 0x00000001) {
459 ifs.read((
char *)&w,
sizeof(WType));
461 if ((flags & 0x00000002) == 0x00000002) {
462 ifs.read((
char *)&index,
sizeof(
int));
// A failed read ends the loop; this is how end-of-file terminates a
// request whose endIndx exceeds the number of records.
465 if (ifs.fail())
break;
// The record is then stored, with weight / index handled according to the
// same flag bits; the storing statements are not visible in the listing.
468 if ((flags & 0x00000001) == 0x00000001) {
471 if ((flags & 0x00000002) == 0x00000002) {
485 template <
typename XType,
typename YType,
typename WType>
513 template <
typename XType,
typename YType,
typename WType>
527 template <
typename XType,
typename YType,
typename WType>
543 template <
typename XType,
typename YType,
typename WType>
// Fragment of the out-of-line definition of subSample(sampleRate, bBalanced).
// The signature, the branch on bBalanced, `else` lines, the statements that
// install the new containers and the closing braces are missing from the
// listing.
559 template <
typename XType,
typename YType,
typename WType>
// NOTE(review): the check accepts sampleRate == 1, but the message says
// "greater than one". The two disagree; the message probably should read
// "greater than zero" -- confirm intent before changing either.
562 DRWN_ASSERT_MSG(sampleRate > 0,
"sampleRate must be greater than one");
// Randomly ordered sample positions (drwn::randomPermutation is defined
// elsewhere; presumably a permutation of 0 .. features.size() - 1).
567 vector<int> indx = drwn::randomPermutation(
features.size());
// Group the shuffled positions by target value.
570 map<YType, vector<int> > stratified;
571 for (
size_t i = 0; i < indx.size(); i++) {
572 typename map<YType, vector<int> >::iterator it = stratified.find(
targets[indx[i]]);
573 if (it == stratified.end()) {
574 stratified.insert(make_pair(
targets[indx[i]], vector<int>(1, indx[i])));
576 it->second.push_back(indx[i]);
// Per-target cap: size of the largest group divided by sampleRate, rounded up.
580 size_t maxSamplesPerTarget = 1;
581 for (
typename map<YType, vector<int> >::const_iterator it = stratified.begin();
582 it != stratified.end(); ++it) {
583 maxSamplesPerTarget = std::max(maxSamplesPerTarget, it->second.size());
585 maxSamplesPerTarget = (maxSamplesPerTarget + sampleRate - 1) / sampleRate;
// Trim every group to the cap and collect the surviving positions in indx.
// NOTE(review): indx is presumably emptied on a line not present here,
// otherwise the survivors would be added after the full permutation.
589 for (
typename map<YType, vector<int> >::iterator it = stratified.begin();
590 it != stratified.end(); ++it) {
591 if (it->second.size() > maxSamplesPerTarget) {
592 it->second.resize(maxSamplesPerTarget);
594 indx.insert(indx.end(), it->second.begin(), it->second.end());
// Presumably the unbalanced branch: keep the first ceil(N / sampleRate)
// positions of the permutation.
599 indx.resize((
features.size() + sampleRate - 1) / sampleRate);
// Replacement containers sized to the kept samples; the weight and index
// containers stay empty when the dataset has none.
603 vector<vector<XType> > nfeatures(indx.size());
604 vector<YType> ntargets(nfeatures.size());
605 vector<WType> nweights(
hasWeights() ? nfeatures.size() : 0);
606 vector<int> nindexes(
hasIndexes() ? nfeatures.size() : 0);
// Move each kept sample into the replacement containers by swapping;
// external indices are copied.
608 for (
size_t i = 0; i < nfeatures.size(); i++) {
609 std::swap(nfeatures[i],
features[indx[i]]);
610 std::swap(ntargets[i],
targets[indx[i]]);
611 if (!nweights.empty()) {
612 std::swap(nweights[i],
weights[indx[i]]);
614 if (!nindexes.empty()) {
615 nindexes[i] =
indexes[indx[i]];
// Result: the number of samples kept.
624 return (
int)indx.size();
vector< WType > weights
weights (optional)
Definition: drwnDataset.h:47
int read(const char *filename, bool bAppend=false)
reads a dataset from disk (optionally appending to the current dataset)
Definition: drwnDataset.h:402
YType minTarget() const
returns the minimum target value in the dataset
Definition: drwnDataset.h:197
drwnDataset< double, double, double > drwnRegressionDataset
standard dataset for supervised regression algorithms
Definition: drwnDataset.h:126
int write(const char *filename, bool bAppend=false) const
writes the current dataset to disk (optionally appending to an existing dataset)
Definition: drwnDataset.h:337
void reserve(int reserveSize)
pre-allocate memory for storing samples (feature vectors and targets)
Definition: drwnDataset.h:182
void clear()
clears all data in the dataset
Definition: drwnDataset.h:328
vector< YType > targets
target labels
Definition: drwnDataset.h:46
vector< vector< XType > > features
feature vectors
Definition: drwnDataset.h:45
bool drwnFileExists(const char *filename)
checks if a file exists
Definition: drwnFileUtils.cpp:323
YType maxTarget() const
returns the maximum target value in the dataset
Definition: drwnDataset.h:204
int count(const YType &label) const
returns the number of samples with a given target label
Definition: drwnDataset.h:169
drwnDataset()
default constructor
Definition: drwnDataset.h:131
int append(const drwnDataset< XType, YType, WType > &d)
appends the samples from another dataset to this dataset
Definition: drwnDataset.h:486
vector< int > indexes
external indices (optional)
Definition: drwnDataset.h:48
int size() const
return the number of samples in the dataset
Definition: drwnDataset.h:63
int numFeatures() const
returns the number of features in the feature vector
Definition: drwnDataset.h:191
drwnDataset< double, int, double > drwnClassifierDataset
standard dataset for supervised classification algorithms
Definition: drwnDataset.h:124
bool empty() const
return true if the dataset is empty
Definition: drwnDataset.h:61
int subSample(int sampleRate, bool bBalanced=false)
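subsample a dataset (balanced is only valid for discrete target types); if bBalanced is true then samples are grouped by target value and each group keeps at most ceil(largest group size / sampleRate) samples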
Definition: drwnDataset.h:560
bool valid() const
return true if the dataset is valid (e.g., number of targets equals number of feature vectors) ...
Definition: drwnDataset.h:155
Implements a cacheable dataset containing feature vectors, labels and optional weights.
Definition: drwnDataset.h:43
bool hasIndexes() const
return true if the dataset has external indices associated with each sample
Definition: drwnDataset.h:67
bool hasWeights() const
return true if the dataset contains weighted samples
Definition: drwnDataset.h:65