diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py
index 324cdc17c9dd..66bfe99d85df 100644
--- a/pytorch_transformers/modeling_utils.py
+++ b/pytorch_transformers/modeling_utils.py
@@ -78,7 +78,7 @@ def save_pretrained(self, save_directory):
         self.to_json_file(output_config_file)
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r""" Instantiate a PretrainedConfig from a pre-trained model configuration.
 
         Params:
@@ -91,20 +91,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
             **cache_dir**: (`optional`) string:
                 Path to a directory in which a downloaded pre-trained model
                 configuration should be cached if the standard cache should not be used.
+            **return_unused_kwargs**: (`optional`) bool:
+                - If False, then this function returns just the final configuration object.
+                - If True, then this function returns a tuple `(config, unused_kwargs)` where `unused_kwargs`
+                    is a dictionary consisting of the key/value pairs whose keys are not configuration attributes,
+                    i.e. the part of kwargs which has not been used to update `config` and is otherwise ignored.
             **kwargs**: (`optional`) dict:
-                Dictionnary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters.
+                Dictionary of key/value pairs with which to update the configuration object after loading.
+                - The values in kwargs of any keys which are configuration attributes will be used
+                    to override the loaded values.
+                - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
+                    by the `return_unused_kwargs` keyword parameter.
 
         Examples::
 
            >>> config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            >>> config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            >>> config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
-           >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True)
+           >>> config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
            >>> assert config.output_attention == True
+           >>> config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
+           >>>                                                    foo=False, return_unused_kwargs=True)
+           >>> assert config.output_attention == True
+           >>> assert unused_kwargs == {'foo': False}
 
         """
         cache_dir = kwargs.pop('cache_dir', None)
+        return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
 
         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
             config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
@@ -148,7 +161,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *input, **kwargs):
                 kwargs.pop(key, None)
 
         logger.info("Model config %s", config)
-        return config
+        if return_unused_kwargs:
+            return config, kwargs
+        else:
+            return config
 
     @classmethod
     def from_dict(cls, json_object):
@@ -305,7 +321,7 @@ def save_pretrained(self, save_directory):
         torch.save(model_to_save.state_dict(), output_model_file)
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
 
         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are desactivated)
@@ -322,6 +338,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
                     provided as `config` argument. This loading option is slower than converting the TensorFlow
                     checkpoint in a PyTorch model using the provided conversion scripts and loading
                     the PyTorch model afterwards.
+            **model_args**: (`optional`) Sequence:
+                All remaining positional arguments will be passed to the underlying model's __init__ function
             **config**: an optional configuration for the model to use instead of an automatically loaded configuation.
                 Configuration can be automatically loaded when:
                     - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
@@ -337,8 +355,17 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
             **output_loading_info**: (`optional`) boolean:
                 Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
             **kwargs**: (`optional`) dict:
-                Dictionnary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters. E.g. ``output_attention=True``
+                Dictionary of key/value pairs with which to update the configuration object after loading.
+                Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
+
+                - If a configuration is provided with `config`, **kwargs will be directly passed
+                    to the underlying model's __init__ method.
+                - If a configuration is not provided, **kwargs will be first passed to the pretrained
+                    model configuration class loading function (`PretrainedConfig.from_pretrained`).
+                    Each key of **kwargs that corresponds to a configuration attribute
+                    will be used to override said attribute with the supplied **kwargs value.
+                    Remaining keys that do not correspond to any configuration attribute will
+                    be passed to the underlying model's __init__ function.
 
         Examples::
 
@@ -359,7 +386,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
 
         # Load config
         if config is None:
-            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+            config, model_kwargs = cls.config_class.from_pretrained(
+                pretrained_model_name_or_path, *model_args,
+                cache_dir=cache_dir, return_unused_kwargs=True,
+                **kwargs
+            )
+        else:
+            model_kwargs = kwargs
 
         # Load model
         if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
@@ -400,7 +433,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
                 archive_file, resolved_archive_file))
 
         # Instantiate model.
-        model = cls(config)
+        model = cls(config, *model_args, **model_kwargs)
 
         if state_dict is None and not from_tf:
             state_dict = torch.load(resolved_archive_file, map_location='cpu')
@@ -530,7 +563,7 @@ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask
             **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
                 hidden states of the first tokens for the labeled span.
             **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
-                position of the first token for the labeled span: 
+                position of the first token for the labeled span:
             **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
                 Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                 1.0 means token should be masked.
@@ -717,7 +750,7 @@ class SequenceSummary(nn.Module):
                 - 'attn' => Not implemented now, use multi-head attention
             summary_use_proj: Add a projection after the vector extraction
             summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
-            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default 
+            summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
             summary_first_dropout: Add a dropout before the projection and activation
             summary_last_dropout: Add a dropout after the projection and activation
     """
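For reference, a minimal usage sketch of the two behaviours introduced above, assuming the `BertConfig`/`BertModel` classes exported by the library; the `foo` key mirrors the docstring example and is deliberately not a configuration attribute:

    from pytorch_transformers import BertConfig, BertModel

    # Keys that are not configuration attributes come back separately
    # when return_unused_kwargs=True (`foo` is an unknown, illustrative key).
    config, unused_kwargs = BertConfig.from_pretrained(
        'bert-base-uncased', foo=False, return_unused_kwargs=True)
    assert unused_kwargs == {'foo': False}

    # With an explicit `config`, remaining kwargs go straight to the model's
    # __init__; without one, they first update the loaded configuration and
    # only the unused ones reach __init__.
    model = BertModel.from_pretrained('bert-base-uncased', config=config)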