-
-
Notifications
You must be signed in to change notification settings - Fork 200
Expand file tree
/
Copy pathgo1_sft_libero.py
More file actions
72 lines (60 loc) · 2.58 KB
/
go1_sft_libero.py
File metadata and controls
72 lines (60 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from dataclasses import dataclass, field
from typing import List, Optional
from transformers import TrainingArguments
from go1.configs.go1_base_cfg import BaseDatasetArguments, BaseModelArguments, BaseSpaceArguments
from go1.tools.env_parse import get_bool_env
# Run name taken from the RUNNAME environment variable; None when it is unset.
RUNNAME = os.getenv("RUNNAME")
# Debug switch taken from the DEBUG_MODE environment variable via get_bool_env.
# NOTE(review): presumably a bool; the parsing rules live in go1.tools.env_parse.
DEBUG_MODE = get_bool_env("DEBUG_MODE")
@dataclass
class DatasetArguments(BaseDatasetArguments):
    """Dataset defaults for this config, overriding BaseDatasetArguments."""

    dataset_type: Optional[str] = "lerobot"
    # Placeholder path: replace it with the real dataset location(s) before use.
    data_root_dir: Optional[List[str]] = field(
        default_factory=lambda: ["/path/to/your/libero/dataset"],
    )
    # NOTE(review): annotated as a list of str, yet the default holds a dict;
    # confirm which form the consumer of `transforms` expects.
    transforms: Optional[List[str]] = field(
        default_factory=lambda: [{"type": "Normalize"}],
    )
@dataclass
class GOModelArguments(BaseModelArguments):
    """Model defaults for this config, overriding BaseModelArguments."""

    model_name_or_path: str = "agibot-world/GO-1"
    # These three freeze flags default to True only when DEBUG_MODE is truthy.
    freeze_llm: bool = True if DEBUG_MODE else False
    freeze_backbone: bool = True if DEBUG_MODE else False
    freeze_mlp: bool = True if DEBUG_MODE else False
    action_chunk_size: int = 10
    latent_planning: bool = True
    freeze_latent_planner: bool = False
@dataclass
class GOTrainingArguments(TrainingArguments):
    """Training defaults for this config, overriding transformers.TrainingArguments."""

    # RUNNAME is read with os.environ.get, so an unset variable produces the
    # literal directory "experiment/None".
    output_dir: str = f"experiment/{RUNNAME}"
    overwrite_output_dir: bool = True
    # Debug mode: no dataloader worker processes (0 instead of 20).
    dataloader_num_workers: int = 0 if DEBUG_MODE else 20
    bf16: bool = True
    num_train_epochs: float = 100.0
    # Debug mode: per-device batch of 2 instead of 16.
    per_device_train_batch_size: int = 2 if DEBUG_MODE else 16
    gradient_accumulation_steps: int = 1

    # Optimizer and learning-rate schedule.
    learning_rate: float = 2e-5
    weight_decay: float = 0.01
    lr_scheduler_type: str = "cosine"
    warmup_steps: int = 1000

    do_train: bool = True
    # Relative path to the DeepSpeed config file.
    deepspeed: str = "go1/zero_stage1_config.json"

    # Checkpointing and logging.
    save_strategy: str = "steps"
    save_steps: int = 10000
    save_total_limit: int = 100
    logging_steps: int = 10
    report_to: str = "tensorboard"
@dataclass
class SpaceArguments(BaseSpaceArguments):
    """State/action space defaults for this config, overriding BaseSpaceArguments."""

    state_dim: int = 8
    action_dim: int = 7
    # NOTE(review): presumably maps the model-side field names (keys) to the
    # dataset's column names (values); confirm against the code that reads it.
    space_repack: dict = field(
        default_factory=lambda: dict(
            state="state",
            action="actions",
            cam_head_color="image",
            cam_hand_left_color="wrist_image",
            final_prompt="task",
        )
    )
    ctrl_freq: int = 10