Add list_feather_columns function in eager mode #404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 5 commits, Aug 5, 2019
11 changes: 6 additions & 5 deletions tensorflow_io/arrow/BUILD
@@ -7,21 +7,22 @@ load(
     "tf_io_copts",
 )

-cc_binary(
-    name = "python/ops/_arrow_ops.so",
+cc_library(
+    name = "arrow_ops",
     srcs = [
         "kernels/arrow_dataset_ops.cc",
+        "kernels/arrow_kernels.cc",
+        "kernels/arrow_kernels.h",
         "kernels/arrow_stream_client.h",
         "kernels/arrow_stream_client_unix.cc",
         "kernels/arrow_util.cc",
         "kernels/arrow_util.h",
         "ops/dataset_ops.cc",
     ],
     copts = tf_io_copts(),
-    linkshared = 1,
+    linkstatic = True,
     deps = [
+        "//tensorflow_io/core:dataset_ops",
         "@arrow",
-        "@local_config_tf//:libtensorflow_framework",
-        "@local_config_tf//:tf_header_lib",
     ],
 )
3 changes: 3 additions & 0 deletions tensorflow_io/arrow/__init__.py
@@ -17,6 +17,7 @@
 @@ArrowDataset
 @@ArrowFeatherDataset
 @@ArrowStreamDataset
+@@list_feather_columns
 """

 from __future__ import absolute_import
@@ -26,13 +27,15 @@
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowDataset
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowFeatherDataset
 from tensorflow_io.arrow.python.ops.arrow_dataset_ops import ArrowStreamDataset
+from tensorflow_io.arrow.python.ops.arrow_dataset_ops import list_feather_columns

 from tensorflow.python.util.all_util import remove_undocumented

 _allowed_symbols = [
     "ArrowDataset",
     "ArrowFeatherDataset",
     "ArrowStreamDataset",
+    "list_feather_columns",
 ]

 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
159 changes: 159 additions & 0 deletions tensorflow_io/arrow/kernels/arrow_kernels.cc
@@ -0,0 +1,159 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow_io/arrow/kernels/arrow_kernels.h"
#include "arrow/io/api.h"
#include "arrow/ipc/feather.h"
#include "arrow/ipc/feather_generated.h"
#include "arrow/buffer.h"

namespace tensorflow {
namespace data {
namespace {

class ListFeatherColumnsOp : public OpKernel {
 public:
  explicit ListFeatherColumnsOp(OpKernelConstruction* context) : OpKernel(context) {
    env_ = context->env();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filename_tensor = context->input(0);
    const string filename = filename_tensor.scalar<string>()();

    const Tensor& memory_tensor = context->input(1);
    const string& memory = memory_tensor.scalar<string>()();
    std::unique_ptr<SizedRandomAccessFile> file(
        new SizedRandomAccessFile(env_, filename, memory.data(), memory.size()));
    uint64 size;
    OP_REQUIRES_OK(context, file->GetFileSize(&size));

    // Feather v1 layout: FEA1 ... [metadata] [uint32 metadata_length] FEA1
    static constexpr const char* kFeatherMagicBytes = "FEA1";

    size_t header_length = strlen(kFeatherMagicBytes);
    size_t footer_length = sizeof(uint32) + strlen(kFeatherMagicBytes);

    string buffer;
    buffer.resize(header_length > footer_length ? header_length : footer_length);

    StringPiece result;

    // Validate the leading magic bytes.
    OP_REQUIRES_OK(context, file->Read(0, header_length, &result, &buffer[0]));
    OP_REQUIRES(context,
                !memcmp(buffer.data(), kFeatherMagicBytes, header_length),
                errors::InvalidArgument("not a feather file"));

    // Validate the trailing magic bytes; the footer also carries the
    // uint32 metadata length just before them.
    OP_REQUIRES_OK(context, file->Read(size - footer_length, footer_length, &result, &buffer[0]));
    OP_REQUIRES(context,
                !memcmp(buffer.data() + sizeof(uint32), kFeatherMagicBytes, footer_length - sizeof(uint32)),
                errors::InvalidArgument("incomplete feather file"));

    uint32 metadata_length = *reinterpret_cast<const uint32*>(buffer.data());

    buffer.resize(metadata_length);

    // The flatbuffer-encoded CTable metadata sits immediately before the footer.
    OP_REQUIRES_OK(context, file->Read(size - footer_length - metadata_length, metadata_length, &result, &buffer[0]));

    const ::arrow::ipc::feather::fbs::CTable* table =
        ::arrow::ipc::feather::fbs::GetCTable(buffer.data());

    OP_REQUIRES(context,
                (table->version() >= ::arrow::ipc::feather::kFeatherVersion),
                errors::InvalidArgument("feather file is old: ", table->version(), " vs. ", ::arrow::ipc::feather::kFeatherVersion));

    std::vector<string> columns;
    std::vector<string> dtypes;
    std::vector<int64> counts;
    columns.reserve(table->columns()->size());
    dtypes.reserve(table->columns()->size());
    counts.reserve(table->columns()->size());

    for (int64 i = 0; i < table->columns()->size(); i++) {
      DataType dtype = ::tensorflow::DataType::DT_INVALID;
      switch (table->columns()->Get(i)->values()->type()) {
        case ::arrow::ipc::feather::fbs::Type_BOOL:
          dtype = ::tensorflow::DataType::DT_BOOL;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT8:
          dtype = ::tensorflow::DataType::DT_INT8;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT16:
          dtype = ::tensorflow::DataType::DT_INT16;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT32:
          dtype = ::tensorflow::DataType::DT_INT32;
          break;
        case ::arrow::ipc::feather::fbs::Type_INT64:
          dtype = ::tensorflow::DataType::DT_INT64;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT8:
          dtype = ::tensorflow::DataType::DT_UINT8;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT16:
          dtype = ::tensorflow::DataType::DT_UINT16;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT32:
          dtype = ::tensorflow::DataType::DT_UINT32;
          break;
        case ::arrow::ipc::feather::fbs::Type_UINT64:
          dtype = ::tensorflow::DataType::DT_UINT64;
          break;
        case ::arrow::ipc::feather::fbs::Type_FLOAT:
          dtype = ::tensorflow::DataType::DT_FLOAT;
          break;
        case ::arrow::ipc::feather::fbs::Type_DOUBLE:
          dtype = ::tensorflow::DataType::DT_DOUBLE;
          break;
        case ::arrow::ipc::feather::fbs::Type_UTF8:
        case ::arrow::ipc::feather::fbs::Type_BINARY:
        case ::arrow::ipc::feather::fbs::Type_CATEGORY:
        case ::arrow::ipc::feather::fbs::Type_TIMESTAMP:
        case ::arrow::ipc::feather::fbs::Type_DATE:
        case ::arrow::ipc::feather::fbs::Type_TIME:
        // case ::arrow::ipc::feather::fbs::Type_LARGE_UTF8:
        // case ::arrow::ipc::feather::fbs::Type_LARGE_BINARY:
        default:
          // Unsupported column types are reported as DT_INVALID.
          break;
      }
      columns.push_back(table->columns()->Get(i)->name()->str());
      dtypes.push_back(::tensorflow::DataTypeString(dtype));
      counts.push_back(table->num_rows());
    }

    TensorShape output_shape = filename_tensor.shape();
    output_shape.AddDim(columns.size());

    Tensor* columns_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &columns_tensor));
    Tensor* dtypes_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(1, output_shape, &dtypes_tensor));

    output_shape.AddDim(1);

    Tensor* shapes_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(2, output_shape, &shapes_tensor));

    for (size_t i = 0; i < columns.size(); i++) {
      columns_tensor->flat<string>()(i) = columns[i];
      dtypes_tensor->flat<string>()(i) = dtypes[i];
      shapes_tensor->flat<int64>()(i) = counts[i];
    }
  }

 private:
  mutex mu_;
  Env* env_ GUARDED_BY(mu_);
};

REGISTER_KERNEL_BUILDER(Name("ListFeatherColumns").Device(DEVICE_CPU),
                        ListFeatherColumnsOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow
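For reference, the footer walk done by the kernel can be reproduced outside TensorFlow. A minimal pure-Python sketch (not part of this PR; assumes a Feather v1 file and a hypothetical path) that validates the magic bytes and extracts the metadata block the same way:

import struct

def feather_v1_footer(path):
  """Mirrors the header/footer validation in ListFeatherColumnsOp."""
  with open(path, "rb") as f:
    data = f.read()
  if data[:4] != b"FEA1":
    raise ValueError("not a feather file")
  if data[-4:] != b"FEA1":
    raise ValueError("incomplete feather file")
  # The uint32 metadata length sits just before the trailing magic bytes.
  (metadata_length,) = struct.unpack("<I", data[-8:-4])
  # The flatbuffer-encoded CTable metadata precedes the length field.
  metadata = data[-8 - metadata_length:-8]
  return metadata_length, metadata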
80 changes: 80 additions & 0 deletions tensorflow_io/arrow/kernels/arrow_kernels.h
@@ -0,0 +1,80 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "kernels/stream.h"
#include "arrow/io/api.h"
#include "arrow/buffer.h"

namespace tensorflow {
namespace data {

// NOTE: Both SizedRandomAccessFile and ArrowRandomAccessFile overlap
// with another PR; the duplicate will be removed once that PR is merged.

// Adapts a tensorflow::RandomAccessFile to the ::arrow::io::RandomAccessFile
// interface. Only size queries and positional reads are implemented, which
// is all the Feather metadata parsing needs.
class ArrowRandomAccessFile : public ::arrow::io::RandomAccessFile {
 public:
  explicit ArrowRandomAccessFile(tensorflow::RandomAccessFile* file, int64 size)
      : file_(file)
      , size_(size) { }

  ~ArrowRandomAccessFile() {}
  arrow::Status Close() override {
    return arrow::Status::OK();
  }
  arrow::Status Tell(int64_t* position) const override {
    return arrow::Status::NotImplemented("Tell");
  }
  arrow::Status Seek(int64_t position) override {
    return arrow::Status::NotImplemented("Seek");
  }
  arrow::Status Read(int64_t nbytes, int64_t* bytes_read, void* out) override {
    return arrow::Status::NotImplemented("Read (void*)");
  }
  arrow::Status Read(int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
    return arrow::Status::NotImplemented("Read (Buffer*)");
  }
  arrow::Status GetSize(int64_t* size) override {
    *size = size_;
    return arrow::Status::OK();
  }
  bool supports_zero_copy() const override {
    return false;
  }
  arrow::Status ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read, void* out) override {
    StringPiece result;
    Status status = file_->Read(position, nbytes, &result, (char*)out);
    // A short read at the end of the file surfaces as OutOfRange in
    // TensorFlow but is not an error for Arrow.
    if (!(status.ok() || errors::IsOutOfRange(status))) {
      return arrow::Status::IOError(status.error_message());
    }
    *bytes_read = result.size();
    return arrow::Status::OK();
  }
  arrow::Status ReadAt(int64_t position, int64_t nbytes, std::shared_ptr<arrow::Buffer>* out) override {
    string buffer;
    buffer.resize(nbytes);
    StringPiece result;
    Status status = file_->Read(position, nbytes, &result, (char*)(&buffer[0]));
    if (!(status.ok() || errors::IsOutOfRange(status))) {
      return arrow::Status::IOError(status.error_message());
    }
    buffer.resize(result.size());
    return arrow::Buffer::FromString(buffer, out);
  }
 private:
  tensorflow::RandomAccessFile* file_;
  int64 size_;
};
}  // namespace data
}  // namespace tensorflow
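pyarrow ships the same adapter idea in the other direction: pyarrow.PythonFile wraps an ordinary Python file object into an Arrow NativeFile. A short illustration of that pattern (assumes pyarrow is installed and a hypothetical example.feather on disk):

import pyarrow as pa

# Wrap a plain Python file object so Arrow readers can consume it,
# much as ArrowRandomAccessFile wraps a tensorflow::RandomAccessFile.
with open("example.feather", "rb") as f:
  source = pa.PythonFile(f, mode="r")
  print(source.size())   # analogous to GetSize()
  source.seek(0)
  print(source.read(4))  # b"FEA1", the leading magic bytes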
14 changes: 14 additions & 0 deletions tensorflow_io/arrow/ops/dataset_ops.cc
@@ -67,4 +67,18 @@ Creates a dataset that connects to a host serving Arrow RecordBatches in stream
 endpoints: One or more host addresses that are serving an Arrow stream.
 )doc");

+
+REGISTER_OP("ListFeatherColumns")
+    .Input("filename: string")
+    .Input("memory: string")
+    .Output("columns: string")
+    .Output("dtypes: string")
+    .Output("shapes: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->MakeShape({c->UnknownDim()}));
+      c->set_output(1, c->MakeShape({c->UnknownDim()}));
+      c->set_output(2, c->MakeShape({c->UnknownDim(), c->UnknownDim()}));
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
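The shape function only pins down ranks (columns and dtypes are vectors, shapes is a matrix); the concrete sizes come from the kernel at runtime. A rough eager-mode sketch of the raw op's contract (core_ops is tensorflow_io's internal op module as wired up in this PR; example.feather is a hypothetical path):

import tensorflow as tf
from tensorflow_io.core.python.ops import core_ops

assert tf.executing_eagerly()
columns, dtypes, shapes = core_ops.list_feather_columns(
    "example.feather", memory="")
print(columns.shape)  # (num_columns,)   column names
print(dtypes.shape)   # (num_columns,)   dtype strings such as "int64"
print(shapes.shape)   # (num_columns, 1) one row count per column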
21 changes: 16 additions & 5 deletions tensorflow_io/arrow/python/ops/arrow_dataset_ops.py
@@ -30,8 +30,7 @@
 from tensorflow.compat.v2 import data
 from tensorflow.python.data.ops.dataset_ops import flat_structure
 from tensorflow.python.data.util import structure as structure_lib
-from tensorflow_io import _load_library
-arrow_ops = _load_library('_arrow_ops.so')
+from tensorflow_io.core.python.ops import core_ops

 if hasattr(tf, "nest"):
   from tensorflow import nest # pylint: disable=ungrouped-imports
@@ -183,7 +182,7 @@ def __init__(self,
         "auto" (size to number of records in Arrow record batch)
     """
     super(ArrowDataset, self).__init__(
-        partial(arrow_ops.arrow_dataset, serialized_batches),
+        partial(core_ops.arrow_dataset, serialized_batches),
         columns,
         output_types,
         output_shapes,
@@ -316,7 +315,7 @@ def __init__(self,
         dtype=dtypes.string,
         name="filenames")
     super(ArrowFeatherDataset, self).__init__(
-        partial(arrow_ops.arrow_feather_dataset, filenames),
+        partial(core_ops.arrow_feather_dataset, filenames),
         columns,
         output_types,
         output_shapes,
@@ -401,7 +400,7 @@ def __init__(self,
         dtype=dtypes.string,
         name="endpoints")
     super(ArrowStreamDataset, self).__init__(
-        partial(arrow_ops.arrow_stream_dataset, endpoints),
+        partial(core_ops.arrow_stream_dataset, endpoints),
         columns,
         output_types,
         output_shapes,
@@ -594,3 +593,15 @@ def gen_record_batches():
         batch_size=batch_size,
         batch_mode='keep_remainder',
         record_batch_iter_factory=gen_record_batches)
+
+def list_feather_columns(filename, **kwargs):
+  """Lists the column names, dtypes, and shapes of a Feather file (eager mode only)."""
+  if not tf.executing_eagerly():
+    raise NotImplementedError("list_feather_columns only supports eager mode")
+  memory = kwargs.get("memory", "")
+  columns, dtypes_, shapes = core_ops.list_feather_columns(
+      filename, memory=memory)
+  entries = zip(tf.unstack(columns), tf.unstack(dtypes_), tf.unstack(shapes))
+  return dict([(column.numpy().decode(), tf.TensorSpec(
+      shape.numpy(), dtype.numpy().decode(), column.numpy().decode())) for (
+          column, dtype, shape) in entries])
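A quick end-to-end usage sketch (not part of the diff): write a small Feather file with pyarrow, then list its columns through the new API. It assumes pyarrow and pandas are installed, a pyarrow version that writes Feather v1 (the layout this kernel parses), and a hypothetical example.feather path:

import pandas as pd
import pyarrow.feather as feather
import tensorflow_io.arrow as arrow_io

df = pd.DataFrame({"x": [1, 2, 3], "y": [1.0, 2.0, 3.0]})
feather.write_feather(df, "example.feather")

# Eager mode only; returns a dict of column name -> tf.TensorSpec.
specs = arrow_io.list_feather_columns("example.feather")
for name, spec in specs.items():
  print(name, spec.dtype, spec.shape)
# Expected, roughly: x int64 (3,) and y float64 (3,)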
1 change: 1 addition & 0 deletions tensorflow_io/core/BUILD
@@ -125,6 +125,7 @@ cc_binary(
     linkshared = 1,
     deps = [
         ":core_ops",
+        "//tensorflow_io/arrow:arrow_ops",
         "//tensorflow_io/audio:audio_ops",
         "//tensorflow_io/avro:avro_ops",
         "//tensorflow_io/azure:azfs_ops",
2 changes: 1 addition & 1 deletion tensorflow_io/parquet/BUILD
@@ -16,7 +16,7 @@ cc_library(
     copts = tf_io_copts(),
     linkstatic = True,
     deps = [
-        "//tensorflow_io/core:dataset_ops",
+        "//tensorflow_io/arrow:arrow_ops",
         "@arrow",
     ],
 )