Skip to content

Commit 6ccad58

Browse files
q10 authored and facebook-github-bot committed
Add abstractions for writing out data (flesh out D71147675, pt 1) (pytorch#947)
Summary: Pull Request resolved: facebookresearch/FBGEMM#947 X-link: pytorch#3856 - Add abstractions for writing out data to FileStorage. This folds in common code from D71147675, D69883868, and D64932387 Reviewed By: spcyppt Differential Revision: D71350076 fbshipit-source-id: 6f0e01d15e5afa79e62db9c4af46c7fab2980975
1 parent f48b70e commit 6ccad58

File tree

3 files changed

+321
-1
lines changed

3 files changed

+321
-1
lines changed

fbgemm_gpu/fbgemm_gpu/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77

88
# pyre-unsafe
99

10-
from fbgemm_gpu.utils.torch_library import TorchLibraryFragment # noqa F401
10+
from .filestore import FileStore # noqa F401
11+
from .torch_library import TorchLibraryFragment # noqa F401
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
# pyre-ignore-all-errors[56]
10+
11+
import io
12+
import logging
13+
import os
14+
import shutil
15+
from dataclasses import dataclass
16+
from pathlib import Path
17+
from typing import BinaryIO, Union
18+
19+
import torch
20+
21+
# Module-level logger used by FileStore for its info / error messages.
logger: logging.Logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class FileStore:
    """
    A basic file store implementation for easy data reads / writes / deletes.

    This class is intended to be used as a utility inside the FBGEMM_GPU codebase
    for consistent writing of tensors and other objects to the filesystem.

    Attributes:
        bucket (str): A directory in the filesystem. It must already exist
            when the FileStore is constructed.
    """

    bucket: str

    def __post_init__(self) -> None:
        """
        Validates that `self.bucket` points to an existing directory.

        Raises:
            ValueError: If `self.bucket` is not an existing directory.
        """
        if not os.path.isdir(self.bucket):
            raise ValueError(f"Directory {self.bucket} does not exist")

    def _filepath(self, path: str) -> str:
        """Returns the location of `path` inside `self.bucket`."""
        return f"{self.bucket}/{path}"

    def write(
        self,
        path: str,
        raw_input: Union[BinaryIO, torch.Tensor, Path],
        ttls: int = 864000,
    ) -> "FileStore":
        """
        Writes a binary stream, a torch.Tensor, or the contents of an existing
        file to the file located at `path` (relative to `self.bucket`).

        Args:
            path (str): The destination file path, relative to `self.bucket`.
            raw_input (BinaryIO | torch.Tensor | Path): The data to write.
                Tensors are serialized with `torch.save`, a `Path` is copied
                from the file it points to, and binary streams are rewound to
                their start and copied in 4 KB chunks.
            ttls (int): The time to live for the data in seconds. Defaults to
                10 days. This implementation does not use the value;
                presumably it is kept for parity with other FileStore
                backends (TODO confirm).

        Returns:
            self. This allows for method-chaining.

        Raises:
            FileExistsError: If the destination file already exists.
            FileNotFoundError: If `raw_input` is a Path that does not exist.
            TypeError: If `raw_input` is of an unsupported type.
        """
        filepath = self._filepath(path)
        event = f"writing to {filepath}"
        logger.info(f"FileStore: {event}")

        try:
            if os.path.isfile(filepath):
                raise FileExistsError(
                    f"File {filepath} already exists in the filesystem"
                )

            if isinstance(raw_input, torch.Tensor):
                torch.save(raw_input, filepath)

            elif isinstance(raw_input, Path):
                if not os.path.exists(raw_input):
                    raise FileNotFoundError(f"File {raw_input} does not exist")
                shutil.copyfile(raw_input, filepath)

            # Objects returned by open(..., "rb") and io.BytesIO are instances
            # of the io.* base classes, not of typing.BinaryIO, so both families
            # have to be checked to accept real binary file handles.
            elif isinstance(
                raw_input, (io.RawIOBase, io.BufferedIOBase, BinaryIO)
            ):
                # Rewind before creating the destination so that a stream that
                # cannot seek does not leave an empty file behind.
                raw_input.seek(0)
                with open(filepath, "wb") as file:
                    while chunk := raw_input.read(4096):  # Read 4 KB at a time
                        file.write(chunk)

            else:
                raise TypeError(f"Unsupported input type: {type(raw_input)}")

        except Exception as e:
            logger.error(f"FileStore: exception occurred when {event}: {e}")
            raise

        return self

    def read(self, path: str) -> io.BytesIO:
        """
        Reads a file into a BytesIO object.

        Args:
            path (str): The path of the file (relative to `self.bucket`) to be
                read.

        Returns:
            Data from the file in BytesIO object format.

        Raises:
            FileNotFoundError: If the file does not exist in the FileStore.
        """
        filepath = self._filepath(path)
        event = f"reading from {filepath}"
        logger.info(f"FileStore: {event}")

        try:
            if not os.path.isfile(filepath):
                raise FileNotFoundError(
                    f"File {filepath} does not exist in the FileStore"
                )

            # Use a context manager so the file handle is closed right away
            # instead of being left for the garbage collector.
            with open(filepath, "rb") as file:
                return io.BytesIO(file.read())

        except Exception as e:
            logger.error(f"FileStore: exception occurred when {event}: {e}")
            raise

    def remove(self, path: str) -> "FileStore":
        """
        Removes a file from the file store. Paths that do not point to a
        regular file (missing paths, directories) are left untouched.

        Args:
            path (str): The path of the file (relative to `self.bucket`) to be
                removed.

        Returns:
            self. This allows for method-chaining.
        """
        filepath = self._filepath(path)
        event = f"deleting {filepath}"
        logger.info(f"FileStore: {event}")

        try:
            if os.path.isfile(filepath):
                os.remove(filepath)

        except Exception as e:
            logger.error(f"FileStore: exception occurred when {event}: {e}")
            raise

        return self

    def exists(self, path: str) -> bool:
        """
        Checks for existence of file in the file store.

        Args:
            path (str): The file path (relative to `self.bucket`).

        Returns:
            True if file exists, False otherwise.
        """
        return os.path.isfile(self._filepath(path))
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import io
11+
import random
12+
import string
13+
import tempfile
14+
import unittest
15+
from pathlib import Path
16+
from typing import BinaryIO, Optional, Union
17+
18+
import fbgemm_gpu
19+
import torch
20+
21+
# Flag read from the fbgemm_gpu package; it defaults to False when the
# attribute is missing. Used below to skip the FB-only tests.
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)


def _random_name(length: int = 15) -> str:
    """Returns a random string of ASCII letters of the given length."""
    return "".join(random.choices(string.ascii_letters, k=length))


class FileStoreTest(unittest.TestCase):
    """Read / write / remove round-trip tests for the FileStore implementations."""

    def _test_filestore_readwrite(
        self,
        # pyre-fixme[2]
        store,  # FileStore
        input: Union[BinaryIO, torch.Tensor, Path],
        path: Optional[str] = None,
    ) -> None:
        """
        Generic FileStore routines to test reading and writing data

        Args:
            store (FileStore): The FileStore to test
            input (torch.Tensor | BinaryIO | Path): The data to write to the FileStore
            path (str, optional): The path to write the data to. If not provided, a random path will be generated.
        """
        if path is None:
            path = _random_name()

        self.assertFalse(store.exists(path), f"{path} already exists")
        store.write(path, input)

        try:
            self.assertTrue(store.exists(path), f"{path} does not exist")

            if isinstance(input, torch.Tensor):
                self.assertTrue(
                    torch.load(store.read(path)).equal(input),
                    "tensors do not match",
                )

            elif isinstance(input, Path):
                self.assertEqual(
                    store.read(path).read(),
                    input.read_bytes(),
                    "file contents do not match",
                )

            # io.BytesIO and real binary file handles derive from the io.*
            # base classes rather than typing.BinaryIO, so check both.
            elif isinstance(input, (io.RawIOBase, io.BufferedIOBase, BinaryIO)):
                input.seek(0)
                self.assertEqual(
                    store.read(path).read(),
                    input.read(),
                    "byte streams do not match",
                )

        finally:
            # Always clean up, so a failed comparison does not leave the
            # written file behind in the store.
            store.remove(path)

        self.assertFalse(store.exists(path), f"{path} is not removed")

    def test_filestore_oss_bad_bucket(self) -> None:
        """
        Test that OSS FileStore raises ValueError when an invalid bucket is provided
        """
        from fbgemm_gpu.utils import FileStore

        self.assertRaises(ValueError, FileStore, _random_name())

    def test_filestore_oss_binaryio(self) -> None:
        """
        Test that OSS FileStore can read and write binary data
        """
        from fbgemm_gpu.utils import FileStore

        # Use a private temporary directory as the bucket instead of the
        # shared /tmp, so runs are isolated and always cleaned up.
        with tempfile.TemporaryDirectory() as bucket:
            self._test_filestore_readwrite(
                FileStore(bucket),
                io.BytesIO(_random_name(128).encode()),
            )

    def test_filestore_oss_tensor(self) -> None:
        """
        Test that OSS FileStore can read and write tensors
        """
        from fbgemm_gpu.utils import FileStore

        with tempfile.TemporaryDirectory() as bucket:
            self._test_filestore_readwrite(
                FileStore(bucket),
                torch.rand((random.randint(100, 1000), random.randint(100, 1000))),
            )

    def test_filestore_oss_file(self) -> None:
        """
        Test that OSS FileStore can read and write files
        """
        from fbgemm_gpu.utils import FileStore

        data = torch.rand((random.randint(100, 1000), random.randint(100, 1000)))

        with tempfile.TemporaryDirectory() as bucket, tempfile.NamedTemporaryFile() as infile:
            torch.save(data, infile)
            # Flush buffered bytes so the file is complete on disk before it
            # is re-opened by name.
            infile.flush()

            self._test_filestore_readwrite(FileStore(bucket), Path(infile.name))

    @unittest.skipIf(open_source, "Test does not apply to OSS")
    def test_filestore_fb_bad_bucket(self) -> None:
        """
        Test that FB FileStore raises ValueError when an invalid bucket is provided
        """
        from fbgemm_gpu.fb.utils import FileStore

        self.assertRaises(ValueError, FileStore, _random_name())

    @unittest.skipIf(open_source, "Test does not apply to OSS")
    def test_filestore_fb_binaryio(self) -> None:
        """
        Test that FB FileStore can read and write binary data
        """
        from fbgemm_gpu.fb.utils import FileStore

        self._test_filestore_readwrite(
            FileStore("tlparse_reports"),
            io.BytesIO(_random_name(128).encode()),
            f"tree/{_random_name()}.unittest",
        )

    @unittest.skipIf(open_source, "Test does not apply to OSS")
    def test_filestore_fb_tensor(self) -> None:
        """
        Test that FB FileStore can read and write tensors
        """
        from fbgemm_gpu.fb.utils import FileStore

        self._test_filestore_readwrite(
            FileStore("tlparse_reports"),
            torch.rand((random.randint(100, 1000), random.randint(100, 1000))),
            f"tree/{_random_name()}.unittest",
        )

    @unittest.skipIf(open_source, "Test does not apply to OSS")
    def test_filestore_fb_file(self) -> None:
        """
        Test that FB FileStore can read and write files
        """
        from fbgemm_gpu.fb.utils import FileStore

        data = torch.rand((random.randint(100, 1000), random.randint(100, 1000)))

        with tempfile.NamedTemporaryFile() as infile:
            torch.save(data, infile)
            # Flush buffered bytes so the file is complete on disk before it
            # is re-opened by name.
            infile.flush()

            self._test_filestore_readwrite(
                FileStore("tlparse_reports"),
                Path(infile.name),
                f"tree/{_random_name()}.unittest",
            )

0 commit comments

Comments
 (0)