def rename(self, name: str) -> None:
    """Rename the dataset and persist the new name.

    The new name is stored on the instance, written out through
    ``_store_meta()`` (whose return value replaces ``self.meta``), and the
    dataset is then flushed.

    If writing the metadata fails -- for example because ``name`` cannot be
    serialised to JSON -- the previous name is restored before the exception
    propagates, so ``self.name`` never reports a name that was not stored.

    Args:
        name: The new dataset name.

    Raises:
        Exception: anything raised by ``_store_meta()`` or ``flush()`` is
            re-raised unchanged.
    """
    previous_name = self._name
    self._name = name  # _store_meta() reads the name from the instance
    try:
        self.meta = self._store_meta()
    except Exception:
        # Roll back so the in-memory name keeps matching the stored metadata.
        self._name = previous_name
        raise
    self.flush()
def test_dataset_name():
    """Check that a dataset keeps its name on creation and after a rename."""
    path = "./data/test_ds_name"

    created = Dataset(
        path,
        shape=(10,),
        schema={"temp": "uint8"},
        name="my_dataset",
        mode="w",
    )
    created.flush()
    assert created.name == "my_dataset"

    # Rename through a second handle opened on the same path.
    reopened = Dataset(path)
    reopened.rename("my_dataset_2")
    assert reopened.name == "my_dataset_2"

    # A fresh handle has to read the renamed value back.
    latest = Dataset(path)
    assert latest.name == "my_dataset_2"