Welcome to TFSnippet

TFSnippet is a set of utilities for writing and testing TensorFlow models.

The design philosophy of TFSnippet is non-interfering: it aims to provide a set of useful utilities that can be used alongside any other TensorFlow library or framework.

Installation

pip install git+https://github.com/thu-ml/zhusuan.git
pip install git+https://github.com/haowen-xu/tfsnippet.git

Documentation

API Docs

tfsnippet

tfsnippet Package
Functions
as_distribution(distribution) Convert a supported type of distribution into Distribution type.
reduce_group_ndims(operation, tensor, …[, …]) Reduce the last group_ndims dimensions in tensor, using operation.
summarize_variables(variables[, title, …]) Get a formatted summary about the variables.
auto_batch_weight(*batch_arrays) Automatically inspect the metric weight for an evaluation mini-batch.
merge_feed_dict(*feed_dicts) Merge all feed dicts into one.
resolve_feed_dict(feed_dict[, inplace]) Resolve all dynamic values in feed_dict into fixed values.
elbo_objective(log_joint, latent_log_prob[, …]) Derive the ELBO objective.
importance_sampling_log_likelihood(…[, …]) Compute \(\log p(\mathbf{x})\) by importance sampling.
iwae_estimator(log_values, axis[, keepdims, …]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}^{(1:K)}|\mathbf{x})}\Big[\log \frac{1}{K} \sum_{k=1}^K f\big(\mathbf{x},\mathbf{z}^{(k)}\big)\Big]\), by IWAE (Burda, Y., Grosse, R.
monte_carlo_objective(log_joint, latent_log_prob) Derive the Monte-Carlo objective.
nvil_estimator(values, latent_log_joint[, …]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\big[f(\mathbf{x},\mathbf{z})\big]\), by NVIL (Mnih and Gregor, 2014) algorithm.
sgvb_estimator(values[, axis, keepdims, name]) Derive the gradient estimator for \(\mathbb{E}_{q(\mathbf{z}|\mathbf{x})}\big[f(\mathbf{x},\mathbf{z})\big]\), by SGVB (Kingma, D.P.
vimco_estimator(log_values, latent_log_joint) Derive the gradient estimator for
get_config_defaults(config) Get the default config values of config.
register_config_arguments(config, parser[, …]) Register config to the specified argument parser.
model_variable(name[, shape, dtype, …]) Get or create a model variable.
get_model_variables([scope]) Get all model variables (i.e., variables in MODEL_VARIABLES collection).
instance_reuse([method_or_scope, _sentinel, …]) Decorate an instance method to reuse a variable scope automatically.
global_reuse([method_or_scope, _sentinel, scope]) Decorate a function to reuse a variable scope automatically.
add_histogram(tensor[, summary_name, …]) Add the histogram of tensor to the default summary collector, and to collections.
add_summary(summary[, collections]) Add the summary to the default summary collector, and to collections.
default_summary_collector() Get the SummaryCollector object at the top of context stack.
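The ELBO and importance-sampling entries above reduce to simple formulas over K samples z^(k) drawn from q(z|x). The sketch below is plain Python, not TFSnippet's TensorFlow implementation: the helper names `elbo` and `is_log_likelihood` are hypothetical, and it assumes the per-sample values log p(x, z^(k)) and log q(z^(k)|x) have already been computed and are finite.

```python
import math
from typing import Sequence


def log_mean_exp(xs: Sequence[float]) -> float:
    """Stable log((1/K) * sum_k exp(x_k)); assumes finite inputs."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs) / len(xs))


def elbo(log_joint: Sequence[float], latent_log_prob: Sequence[float]) -> float:
    """Monte Carlo ELBO: mean over samples of log p(x, z) - log q(z|x)."""
    diffs = [lj - lq for lj, lq in zip(log_joint, latent_log_prob)]
    return sum(diffs) / len(diffs)


def is_log_likelihood(log_joint: Sequence[float],
                      latent_log_prob: Sequence[float]) -> float:
    """Importance-sampled estimate of log p(x): log-mean-exp of the log weights."""
    log_w = [lj - lq for lj, lq in zip(log_joint, latent_log_prob)]
    return log_mean_exp(log_w)
```

Because log-mean-exp of the log weights is never smaller than their plain mean (Jensen's inequality), the importance-sampled estimate is always at least the ELBO computed from the same samples, and the two coincide when all log weights are equal.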
Classes
BatchToValueDistribution(distribution, ndims) Distribution that converts the last few batch_ndims into values_ndims.
Bernoulli(logits[, dtype]) Univariate Bernoulli distribution.
Categorical(logits[, dtype]) Univariate Categorical distribution.
Concrete(temperature, logits[, …]) The class of Concrete (or Gumbel-Softmax) distribution from (Maddison, 2016; Jang, 2016), serving as the continuous relaxation of the OnehotCategorical.
Discrete alias of tfsnippet.distributions.univariate.Categorical
DiscretizedLogistic(mean, log_scale, bin_size) Discretized logistic distribution (Kingma et.
Distribution(dtype, is_continuous, …) Base class for probability distributions.
ExpConcrete(temperature, logits[, …]) The class of ExpConcrete distribution from (Maddison, 2016), transformed from Concrete by taking logarithm.
FlowDistribution(distribution, flow) Transform a Distribution by a BaseFlow, as a new distribution.
FlowDistributionDerivedTensor(tensor, …) A combination of a FlowDistribution derived tensor, and its original stochastic tensor from the base distribution.
Mixture(categorical, components[, …]) Mixture distribution.
Normal(mean[, std, logstd, …]) Univariate Normal distribution.
OnehotCategorical(logits[, dtype]) One-hot multivariate Categorical distribution.
Uniform([minval, maxval, …]) Univariate Uniform distribution.
AnnealingVariable(name, initial_value, ratio) A non-trainable tf.Variable, whose value will be annealed as training goes by.
CheckpointSavableObject Base class for all objects that can be saved via CheckpointSaver.
CheckpointSaver(variables, save_dir[, …]) Save and restore tf.Variable, ScheduledVariable and CheckpointSavableObject with tf.train.Saver.
DefaultMetricFormatter Default training metric formatter.
EventKeys Defines event keys for TFSnippet.
MetricFormatter Base class for a training metrics formatter.
MetricLogger([summary_writer, …]) Logger for the training metrics.
ScheduledVariable(name, initial_value[, …]) A non-trainable tf.Variable, whose value might need to be changed as training goes by.
TrainLoop(param_vars[, var_groups, …]) Training loop object.
AnnealingScalar(loop, initial_value, ratio) A DynamicValue scalar, which anneals every few epochs or steps.
BaseTrainer(loop[, ensure_variables_initialized]) Base class for all trainers.
DynamicValue Dynamic values to be fed into trainers and evaluators.
Evaluator(loop, metrics, inputs, data_flow) Class to compute evaluation metrics.
LossTrainer(**kwargs) A subclass of BaseTrainer, which optimizes a single loss.
Trainer(loop, train_op, inputs, data_flow[, …]) A subclass of BaseTrainer, executing a training operation per step.
Validator(**kwargs) Class to compute validation loss and other metrics.
VariationalChain(variational, model[, …]) Chain of the variational and model nets for variational inference.
VariationalEvaluation(vi) Factory for variational evaluation outputs.
VariationalInference(log_joint, latent_log_probs) Class for variational inference.
VariationalLowerBounds(vi) Factory for variational lower-bounds.
VariationalTrainingObjectives(vi) Factory for variational training objectives.
BayesianNet([observed]) Bayesian networks.
DataFlow Data flows are objects for constructing mini-batch iterators.
DataMapper Base class for all data mappers.
SlidingWindow(data_array, window_size) DataMapper for producing sliding windows according to indices.
Config Base class for defining config values.
ConfigField(type[, default, description, …]) A config field.
GraphKeys Defines TensorFlow graph collection keys for TFSnippet.
InvertibleMatrix(size[, strict, dtype, …]) A matrix initialized to be an invertible, orthogonal matrix.
VarScopeObject([name, scope]) Base class for objects that own a variable scope.
SummaryCollector([collections, …]) Collects summaries and histograms added by tfsnippet.add_summary() and tfsnippet.add_histogram().
StochasticTensor(distribution, tensor[, …]) Samples or observations of a stochastic variable.
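AnnealingVariable and AnnealingScalar are described above only as values that anneal by `ratio` every few epochs or steps. A minimal pure-Python sketch of such a schedule follows, assuming exponential step decay, value = initial_value * ratio ** (counter // freq), clipped to optional bounds. The function `annealed_value` and its `freq`, `min_value` and `max_value` parameters are hypothetical illustrations, not TFSnippet's API.

```python
from typing import Optional


def annealed_value(initial_value: float, ratio: float, counter: int, freq: int,
                   min_value: Optional[float] = None,
                   max_value: Optional[float] = None) -> float:
    """Hypothetical schedule: multiply by `ratio` once every `freq` epochs (or steps)."""
    if freq <= 0:
        raise ValueError("freq must be positive")
    # Integer division makes the value change stepwise, not continuously.
    value = initial_value * ratio ** (counter // freq)
    if min_value is not None:
        value = max(value, min_value)
    if max_value is not None:
        value = min(value, max_value)
    return value
```

For example, a learning rate of 0.001 halved every 10 epochs with a floor of 1e-4 is 0.001 for epochs 0 to 9, 0.0005 for epochs 10 to 19, and eventually stays at 1e-4.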
Class Inheritance Diagram
Inheritance diagram of tfsnippet.distributions.batch_to_value.BatchToValueDistribution, tfsnippet.distributions.univariate.Bernoulli, tfsnippet.distributions.univariate.Categorical, tfsnippet.distributions.multivariate.Concrete, tfsnippet.distributions.discretized.DiscretizedLogistic, tfsnippet.distributions.base.Distribution, tfsnippet.distributions.multivariate.ExpConcrete, tfsnippet.distributions.flow.FlowDistribution, tfsnippet.distributions.flow.FlowDistributionDerivedTensor, tfsnippet.distributions.mixture.Mixture, tfsnippet.distributions.univariate.Normal, tfsnippet.distributions.multivariate.OnehotCategorical, tfsnippet.distributions.univariate.Uniform, tfsnippet.scaffold.scheduled_var.AnnealingVariable, tfsnippet.scaffold.checkpoint.CheckpointSavableObject, tfsnippet.scaffold.checkpoint.CheckpointSaver, tfsnippet.scaffold.logging_.DefaultMetricFormatter, tfsnippet.scaffold.event_keys.EventKeys, tfsnippet.scaffold.logging_.MetricFormatter, tfsnippet.scaffold.logging_.MetricLogger, tfsnippet.scaffold.scheduled_var.ScheduledVariable, tfsnippet.scaffold.train_loop_.TrainLoop, tfsnippet.trainer.dynamic_values.AnnealingScalar, tfsnippet.trainer.base_trainer.BaseTrainer, tfsnippet.trainer.dynamic_values.DynamicValue, tfsnippet.trainer.evaluator.Evaluator, tfsnippet.trainer.loss_trainer.LossTrainer, tfsnippet.trainer.trainer.Trainer, tfsnippet.trainer.validator.Validator, tfsnippet.variational.chain.VariationalChain, tfsnippet.variational.inference.VariationalEvaluation, tfsnippet.variational.inference.VariationalInference, tfsnippet.variational.inference.VariationalLowerBounds, tfsnippet.variational.inference.VariationalTrainingObjectives, tfsnippet.bayes.BayesianNet, tfsnippet.dataflows.base.DataFlow, tfsnippet.dataflows.data_mappers.DataMapper, tfsnippet.dataflows.data_mappers.SlidingWindow, tfsnippet.utils.config_utils.Config, tfsnippet.utils.config_utils.ConfigField, tfsnippet.utils.graph_keys.GraphKeys, tfsnippet.utils.invertible_matrix.InvertibleMatrix, tfsnippet.utils.reuse.VarScopeObject, tfsnippet.utils.summary_collector.SummaryCollector, tfsnippet.stochastic.StochasticTensor

tfsnippet.dataflows

tfsnippet.dataflows Package
Classes
ArrayFlow(arrays, batch_size[, shuffle, …]) Data flow that uses NumPy-like arrays as its data source.
DataFlow Data flows are objects for constructing mini-batch iterators.
DataMapper Base class for all data mappers.
ExtraInfoDataFlow(array_count, data_length, …) Base class for DataFlow subclasses with auxiliary information about the mini-batches.
GatherFlow(flows) Gathers multiple data flows into a single flow.
IteratorFactoryFlow(factory) Data flow constructed from an iterator factory.
MapperFlow(source, mapper[, array_indices]) Data flow which transforms the mini-batch arrays from the source flow by a specified mapper function.
SeqFlow(start, stop[, step, batch_size, …]) Data flow that uses a number sequence as its data source.
SlidingWindow(data_array, window_size) DataMapper for producing sliding windows according to indices.
ThreadingFlow(source, prefetch) Data flow to prefetch from the source data flow in a background thread.
Class Inheritance Diagram
Inheritance diagram of tfsnippet.dataflows.array_flow.ArrayFlow, tfsnippet.dataflows.base.DataFlow, tfsnippet.dataflows.data_mappers.DataMapper, tfsnippet.dataflows.base.ExtraInfoDataFlow, tfsnippet.dataflows.gather_flow.GatherFlow, tfsnippet.dataflows.iterator_flow.IteratorFactoryFlow, tfsnippet.dataflows.mapper_flow.MapperFlow, tfsnippet.dataflows.seq_flow.SeqFlow, tfsnippet.dataflows.data_mappers.SlidingWindow, tfsnippet.dataflows.threading_flow.ThreadingFlow
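ThreadingFlow prefetches mini-batches from a source flow in a background thread. The general producer/consumer technique behind this can be sketched with the standard library alone. `prefetch` is a hypothetical helper, not TFSnippet's implementation; unlike a real data flow it is single-use and cannot be restarted for a new epoch.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")
_END = object()  # sentinel marking that the source is exhausted


class _Failure:
    """Carries an exception raised in the producer thread to the consumer."""

    def __init__(self, exc: BaseException) -> None:
        self.exc = exc


def prefetch(source: Iterable[T], buffer_size: int) -> Iterator[T]:
    """Read up to `buffer_size` items ahead of the consumer in a background thread."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    buf = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in source:
                buf.put(item)  # blocks while the buffer is full
            buf.put(_END)
        except BaseException as exc:  # forward source errors to the consumer
            buf.put(_Failure(exc))

    def consumer() -> Iterator[T]:
        while True:
            item = buf.get()
            if item is _END:
                return
            if isinstance(item, _Failure):
                raise item.exc
            yield item

    # Start eagerly, so items are fetched before the first next() call.
    threading.Thread(target=producer, daemon=True).start()
    return consumer()
```

The thread is a daemon so that, if the consumer stops early and the producer stays blocked on a full buffer, it does not keep the interpreter from exiting.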

tfsnippet.datasets

tfsnippet.datasets Package
Functions
load_cifar10([channels_last, x_shape, …]) Load the CIFAR-10 dataset as NumPy arrays.
load_cifar100([label_mode, channels_last, …]) Load the CIFAR-100 dataset as NumPy arrays.
load_fashion_mnist([x_shape, x_dtype, …]) Load the Fashion MNIST dataset as NumPy arrays.
load_mnist([x_shape, x_dtype, y_dtype, …]) Load the MNIST dataset as NumPy arrays.

tfsnippet.layers

tfsnippet.layers Package
Functions
act_norm(*args, **kwargs) ActNorm proposed by (Kingma & Dhariwal, 2018).
as_gated(layer_fn[, sigmoid_bias, default_name]) Wrap a layer function into a gated layer function.
avg_pool2d(*args, **kwargs) 2D average pooling over spatial dimensions.
broadcast_log_det_against_input(log_det, …) Broadcast the shape of log_det to match the shape of input.
conv2d(*args, **kwargs) 2D convolutional layer.
deconv2d(*args, **kwargs) 2D deconvolutional layer.
default_kernel_initializer([weight_norm]) Get the default initializer for layer kernels (i.e., W of layers).
dense(*args, **kwargs) Fully-connected layer.
dropout(*args, **kwargs) Apply dropout on input.
global_avg_pool2d(*args, **kwargs) 2D global average pooling over spatial dimensions.
l2_regularizer(lambda_[, name]) Construct an L2 regularizer that computes the L2 regularization loss.
max_pool2d(*args, **kwargs) 2D max pooling over spatial dimensions.
pixelcnn_2d_input(*args, **kwargs) Prepare the input for a PixelCNN 2D network (Tim Salimans, 2017).
pixelcnn_2d_output(input) Get the final output of a PixelCNN 2D network from the previous layer.
pixelcnn_conv2d_resnet(*args, **kwargs) PixelCNN 2D convolutional ResNet block.
planar_normalizing_flows([n_layers, …]) Construct a sequence of PlanarNormalizingFlow layers.
resnet_conv2d_block(*args, **kwargs) 2D convolutional ResNet block.
resnet_deconv2d_block(*args, **kwargs) 2D deconvolutional ResNet block.
resnet_general_block(*args, **kwargs) A general implementation of ResNet block.
shifted_conv2d(*args, **kwargs) 2D convolution with shifted input.
weight_norm(*args, **kwargs) Weight normalization proposed by (Salimans & Kingma, 2016).
Classes
ActNorm([axis, value_ndims, initialized, …]) ActNorm proposed by (Kingma & Dhariwal, 2018).
BaseFlow(x_value_ndims[, y_value_ndims, …]) Base class for normalizing flows.
BaseLayer([name, scope]) Base class for all neural network layers.
CouplingLayer(shift_and_scale_fn[, axis, …]) A general implementation of the coupling layer (Dinh et al., 2016).
FeatureMappingFlow(axis, value_ndims, **kwargs) Base class for flows mapping input features to output features.
FeatureShufflingFlow([axis, value_ndims, …]) An invertible flow which shuffles the order of input features.
InvertFlow(flow[, name, scope]) Turn a BaseFlow into its inverted flow.
InvertibleActivation Base class for invertible activation functions.
InvertibleActivationFlow(activation, value_ndims) A flow that converts an InvertibleActivation into a flow.
InvertibleConv2d([channels_last, …]) Invertible 1x1 2D convolution proposed in (Kingma & Dhariwal, 2018).
InvertibleDense([strict_invertible, …]) Invertible dense layer, modified from the invertible 1x1 2d convolution proposed in (Kingma & Dhariwal, 2018).
LeakyReLU([alpha]) Leaky ReLU activation function.
MultiLayerFlow(n_layers, **kwargs) Base class for multi-layer normalizing flows.
PixelCNN2DOutput(vertical, horizontal) The output of a PixelCNN 2D layer, including tensors from the vertical and horizontal convolution stacks.
PlanarNormalizingFlow([w_initializer, …]) A single layer Planar Normalizing Flow (Danilo 2016) with tanh activation function, as well as the invertible trick.
ReshapeFlow(x_value_ndims, y_value_shape[, …]) A flow which reshapes the last x_value_ndims of x into y_value_shape.
SequentialFlow(flows[, name, scope]) Compose a large flow from a sequence of BaseFlow instances.
SpaceToDepthFlow(block_size[, …]) A flow which computes y = space_to_depth(x), and conversely x = depth_to_space(y).
SplitFlow(split_axis, left[, join_axis, …]) A flow which splits input x into halves, applies different flows to each half, then concatenates the outputs.
Class Inheritance Diagram
Inheritance diagram of tfsnippet.layers.normalization.act_norm_.ActNorm, tfsnippet.layers.flows.base.BaseFlow, tfsnippet.layers.base.BaseLayer, tfsnippet.layers.flows.coupling.CouplingLayer, tfsnippet.layers.flows.base.FeatureMappingFlow, tfsnippet.layers.flows.rearrangement.FeatureShufflingFlow, tfsnippet.layers.flows.invert.InvertFlow, tfsnippet.layers.activations.base.InvertibleActivation, tfsnippet.layers.activations.base.InvertibleActivationFlow, tfsnippet.layers.flows.linear.InvertibleConv2d, tfsnippet.layers.flows.linear.InvertibleDense, tfsnippet.layers.activations.leaky_relu.LeakyReLU, tfsnippet.layers.flows.base.MultiLayerFlow, tfsnippet.layers.convolutional.pixelcnn.PixelCNN2DOutput, tfsnippet.layers.flows.planar_nf.PlanarNormalizingFlow, tfsnippet.layers.flows.reshape.ReshapeFlow, tfsnippet.layers.flows.sequential.SequentialFlow, tfsnippet.layers.flows.reshape.SpaceToDepthFlow, tfsnippet.layers.flows.branch.SplitFlow

tfsnippet.ops

tfsnippet.ops Package
Functions
add_n_broadcast(tensors[, name]) Add zero or more tensors with broadcasting.
assert_rank(x, ndims[, message, name]) Assert the rank of x is ndims.
assert_rank_at_least(x, ndims[, message, name]) Assert the rank of x is at least ndims.
assert_scalar_equal(a, b[, message, name]) Assert 0-d scalar a == b.
assert_shape_equal(x, y[, message, name]) Assert the shape of x equals that of y.
bits_per_dimension(log_p, value_size[, …]) Compute “bits per dimension” of x.
broadcast_concat(x, y, axis[, name]) Broadcast x and y, then concat them along axis.
broadcast_to_shape(x, shape[, name]) Broadcast x to match shape.
broadcast_to_shape_strict(x, shape[, name]) Broadcast x to match shape.
classification_accuracy(y_pred, y_true[, name]) Compute the classification accuracy for y_pred and y_true.
convert_to_tensor_and_cast(x[, dtype]) Convert x into a tf.Tensor, and cast its dtype if required.
depth_to_space(input, block_size[, …]) Wraps tf.depth_to_space(), to support tensors of rank higher than 4.
flatten_to_ndims(x, ndims[, name]) Flatten the front dimensions of x, such that the resulting tensor will have at most ndims dimensions.
log_mean_exp(x[, axis, keepdims, name]) Compute \(\log \frac{1}{K} \sum_{k=1}^K \exp(x_k)\).
log_sum_exp(x[, axis, keepdims, name]) Compute \(\log \sum_{k=1}^K \exp(x_k)\).
maybe_clip_value(x[, min_val, max_val, name]) Maybe clip the elements of x.
pixelcnn_2d_sample(fn, inputs, height, width) Sample output from a PixelCNN 2D network, pixel-by-pixel.
prepend_dims(x[, ndims, name]) Prepend [1] * ndims to the beginning of the shape of x.
reshape_tail(input, ndims, shape[, name]) Reshape the tail (last) ndims into specified shape.
shift(input, shift[, name]) Shift each axis of input according to shift, but keep identical size.
smart_cond(cond, true_fn, false_fn[, name]) Execute true_fn or false_fn according to cond.
softmax_classification_output(logits[, name]) Get the most probable softmax classification output for each logit.
space_to_depth(input, block_size[, …]) Wraps tf.space_to_depth(), to support tensors of rank higher than 4.
transpose_conv2d_axis(input, …[, name]) Ensure the channels axis of the input tensor is placed at the desired axis.
transpose_conv2d_channels_last_to_x(input, …) Ensure the channels axis (known to be the last axis) of the input tensor is placed at the desired axis.
transpose_conv2d_channels_x_to_last(input, …) Ensure the channels axis of the input tensor is placed at the last axis.
unflatten_from_ndims(x, static_front_shape, …) The inverse transformation of flatten_to_ndims().
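The log_sum_exp, log_mean_exp and bits_per_dimension entries above rest on a standard trick: subtracting max(x) before exponentiating so that exp() cannot overflow. The pure-Python sketch below illustrates the formulas on lists of floats rather than tensors; it is not TFSnippet's implementation, and bits_per_dimension's full signature has further optional arguments (shown as … above) that this sketch does not model.

```python
import math
from typing import Sequence


def log_sum_exp(xs: Sequence[float]) -> float:
    """log(sum_k exp(x_k)), shifted by max(x) so that exp() cannot overflow."""
    m = max(xs)
    if math.isinf(m):
        return m  # all -inf, or some +inf: the shift itself would produce NaN
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_mean_exp(xs: Sequence[float]) -> float:
    """log((1/K) * sum_k exp(x_k)) = log_sum_exp(x) - log(K)."""
    return log_sum_exp(xs) - math.log(len(xs))


def bits_per_dimension(log_p: float, value_size: int) -> float:
    """-log p(x) / (log(2) * D): negative log-likelihood in bits, per dimension."""
    return -log_p / (math.log(2) * value_size)
```

A naive `math.log(sum(math.exp(x) for x in [1000.0, 1000.0]))` raises OverflowError, while the shifted version returns 1000 + log 2.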

tfsnippet.preprocessing

tfsnippet.preprocessing Package
Classes
BaseSampler Base class for samplers.
BernoulliSampler([dtype, random_state]) A DataMapper which can sample 0/1 integers according to the input probability.
UniformNoiseSampler([minval, maxval, dtype, …]) A DataMapper which can add uniform noise onto the input array.
Class Inheritance Diagram
Inheritance diagram of tfsnippet.preprocessing.samplers.BaseSampler, tfsnippet.preprocessing.samplers.BernoulliSampler, tfsnippet.preprocessing.samplers.UniformNoiseSampler
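BernoulliSampler and UniformNoiseSampler are DataMappers that randomize mini-batch arrays: the first turns probabilities into 0/1 samples (e.g. binarizing MNIST pixel intensities), the second adds uniform noise (e.g. dequantizing integer pixels). The element-wise logic can be sketched in plain Python as below. The helper names, the use of the standard `random` module instead of a NumPy random state, and the 0 and 1 defaults for minval and maxval are all assumptions for illustration.

```python
import random
from typing import List, Optional, Sequence


def bernoulli_sample(probs: Sequence[float], rng: random.Random) -> List[int]:
    """Map each probability p to 1 with probability p, otherwise 0."""
    # rng.random() is in [0, 1), so p = 0 always gives 0 and p = 1 always gives 1.
    return [1 if rng.random() < p else 0 for p in probs]


def add_uniform_noise(values: Sequence[float], minval: float = 0.0,
                      maxval: float = 1.0,
                      rng: Optional[random.Random] = None) -> List[float]:
    """Add independent U(minval, maxval) noise to every element."""
    rng = rng if rng is not None else random.Random()
    return [v + rng.uniform(minval, maxval) for v in values]
```

Passing an explicitly seeded `random.Random` makes the sampled mini-batches reproducible, which is the same role the `random_state` argument plays in the listed signatures.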

tfsnippet.utils

tfsnippet.utils Package
Functions
DocInherit(kclass) Class decorator to enable kclass and all its sub-classes to automatically inherit docstrings from base classes.
add_histogram(tensor[, summary_name, …]) Add the histogram of tensor to the default summary collector, and to collections.
add_name_and_scope_arg_doc(method) Add name and scope argument to the doc of method.
add_name_arg_doc(method) Add name argument to the doc of method.
add_summary(summary[, collections]) Add the summary to the default summary collector, and to collections.
append_arg_to_doc(doc, arg_doc) Add the doc for name and scope argument to the doc string.
append_to_doc(doc, content) Append content to the doc string.
assert_deps(*args, **kwds) If tfsnippet.settings.enable_assertions == True, open a context that will run assert_ops.
camel_to_underscore(name) Convert a camel-case name to underscore.
concat_shapes(shapes[, name]) Concat shapes from shapes.
create_session([lock_memory, …]) A convenient method to create a TensorFlow session.
default_summary_collector() Get the SummaryCollector object at the top of context stack.
deprecated_arg(old_arg[, new_arg, version])
ensure_variables_initialized([variables, name]) Ensure variables are initialized.
generate_random_seed() Generate a new random seed from the default NumPy random state.
get_batch_size(tensor[, axis, name]) Infer the mini-batch size according to tensor.
get_cache_root() Get the cache root directory.
get_config_defaults(config) Get the default config values of config.
get_config_validator(type) Get an instance of ConfigValidator for specified type.
get_default_scope_name(name[, cls_or_instance]) Generate a valid default scope name.
get_default_session_or_error() Get the default session.
get_dimension_size(tensor, axis[, name]) Get the size of tensor of specified axis.
get_dimensions_size(tensor[, axes, name]) Get the size of tensor of specified axes.
get_model_variables([scope]) Get all model variables (i.e., variables in MODEL_VARIABLES collection).
get_rank(tensor[, name]) Get the rank of the tensor.
get_reuse_stack_top() Get the top of the reuse scope stack.
get_static_shape(tensor) Get the static shape of the specified tensor as a tuple.
get_uninitialized_variables([variables, name]) Get uninitialized variables as a list.
get_variable_ddi(name, initial_value[, …]) Wraps tf.get_variable() to support data-dependent initialization.
get_variables_as_dict([scope, collection]) Get TensorFlow variables as dict.
global_reuse([method_or_scope, _sentinel, scope]) Decorate a function to reuse a variable scope automatically.
humanize_duration(seconds[, short_units]) Format specified time duration as human readable text.
instance_reuse([method_or_scope, _sentinel, …]) Decorate an instance method to reuse a variable scope automatically.
is_float(x) Test whether or not x is a Python or NumPy float.
is_integer(x) Test whether or not x is a Python or NumPy integer.
is_shape_equal(x, y[, name]) Check whether the shape of x equals that of y.
is_tensor_object(x) Test whether or not x is a tensor object.
is_tensorflow_version_higher_or_equal(version) Check whether the version of TensorFlow is higher than or equal to version.
iter_files(root_dir[, sep]) Iterate through all files in root_dir, returning the relative paths of each file.
makedirs(name[, mode, exist_ok])
maybe_add_histogram(tensor[, summary_name, …]) If tfsnippet.settings.auto_histogram == True, add the histogram of tensor via tfsnippet.add_histogram().
maybe_check_numerics(tensor, message[, name]) If tfsnippet.settings.check_numerics == True, check the numerics of tensor.
maybe_close(*args, **kwds) Enter a context, and if obj has .close() method, close it when exiting the context.
minibatch_slices_iterator(length, batch_size) Iterate through all the mini-batch slices.
model_variable(name[, shape, dtype, …]) Get or create a model variable.
print_as_table(title, key_values[, hr]) Print a key-value sequence as a table.
register_config_arguments(config, parser[, …]) Register config to the specified argument parser.
register_config_validator(type, validator_class) Register a config value validator.
register_tensor_wrapper_class(cls) Register a sub-class of TensorWrapper into TensorFlow type system.
reopen_variable_scope(*args, **kwds) Reopen the specified var_scope and its original name scope.
resolve_negative_axis(ndims, axis) Resolve all negative axis indices according to ndims into positive.
root_variable_scope(*args, **kwds) Open the root variable scope and its name scope.
scoped_set_config(*args, **kwds) Set config values within a context scope.
set_cache_root(cache_root) Set the root cache directory.
set_random_seed(seed) Generate random seeds for NumPy, TensorFlow and TFSnippet.
split_numpy_array(array[, portion, size, …]) Split numpy array into two halves, by portion or by size.
split_numpy_arrays(arrays[, portion, size, …]) Split numpy arrays into two halves, by portion or by size.
validate_enum_arg(arg_name, arg_value, choices) Validate the value of an enumeration argument.
validate_group_ndims_arg(group_ndims[, name]) Validate the specified value for group_ndims argument.
validate_int_tuple_arg(arg_name, arg_value) Validate an integer or a tuple of integers, as a tuple of integers.
validate_n_samples_arg(value, name) Validate the n_samples argument.
validate_positive_int_arg(arg_name, arg_value) Validate a positive integer argument.
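minibatch_slices_iterator(length, batch_size) iterates through the mini-batch slices of a dataset of the given length. A minimal sketch of that idea follows; `iter_batch_slices` is a hypothetical re-implementation, and its `skip_incomplete` flag (dropping a final batch shorter than batch_size) is an assumption added for illustration, not part of the signature listed above.

```python
from typing import Iterator


def iter_batch_slices(length: int, batch_size: int,
                      skip_incomplete: bool = False) -> Iterator[slice]:
    """Yield slice objects covering range(length) in chunks of `batch_size`."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, length, batch_size):
        stop = min(start + batch_size, length)
        if skip_incomplete and stop - start < batch_size:
            return  # only the last batch can be incomplete
        yield slice(start, stop)
```

Yielding `slice` objects rather than copied sub-lists lets the same slices index several parallel arrays (for example `x[s]` and `y[s]`) without materializing intermediate copies.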
Classes
AutoInitAndCloseable Classes with init() to initialize their internal states, and close() to destroy these states.
BaseRegistry([ignore_case]) A base class for implementing a type or object registry.
BoolConfigValidator Config value validator for boolean values.
CacheDir(name[, cache_root]) Class to manipulate a cache directory.
ClassRegistry([ignore_case]) A subclass of BaseRegistry, dedicated for classes.
Config Base class for defining config values.
ConfigField(type[, default, description, …]) A config field.
ConfigValidator Base config value validator.
ConsoleTable(col_count[, col_space, …]) A class to help format a console table.
ContextStack([initial_factory]) A thread-local context stack for general purpose.
Disposable Classes which can only be used once.
DisposableContext Base class for contexts which can only be entered once.
ETA([take_initial_snapshot]) Class to help compute the Estimated Time Ahead (ETA).
EventSource([allowed_event_keys]) An object that may trigger events.
Extractor(archive_file) The base class for all archive extractors.
FloatConfigValidator Config value validator for float values.
GraphKeys Defines TensorFlow graph collection keys for TFSnippet.
InputSpec([shape, dtype]) Class to describe the specification for an input tensor.
IntConfigValidator Config value validator for integer values.
InvertibleMatrix(size[, strict, dtype, …]) A matrix initialized to be an invertible, orthogonal matrix.
NoReentrantContext Base class for contexts which are not reentrant (i.e., once a context has been opened by __enter__, __enter__ cannot be called again until __exit__ has been called).
ParamSpec(*args, **kwargs) Class to describe the specification for a parameter.
PermutationMatrix(data) A non-trainable permutation matrix.
RarExtractor(fpath) Extractor for “.rar” files.
StatisticsCollector([shape]) Computes \(\mathrm{E}[X]\) and \(\operatorname{Var}[X]\) online.
StrConfigValidator Config value validator for string values.
SummaryCollector([collections, …]) Collects summaries and histograms added by tfsnippet.add_summary() and tfsnippet.add_histogram().
TFSnippetConfig Global configurations of TFSnippet.
TarExtractor(fpath) Extractor for “.tar”, “.tar.gz”, “.tgz”, “.tar.bz2”, “.tbz”, “.tbz2”, “.tb2”, “.tar.xz”, “.txz” files.
TemporaryDirectory([suffix, prefix, dir]) Create and return a temporary directory.
TensorArgValidator(name) Class to validate argument values of tensors.
TensorSpec([shape, dtype]) Base class to describe and validate the specification of a tensor.
TensorWrapper Tensor-like object that wraps a tf.Tensor instance.
VarScopeObject([name, scope]) Base class for objects that own a variable scope.
VarScopeRandomState(variable_scope) A sub-class of np.random.RandomState, which uses a variable-scope dependent seed.
ZipExtractor(fpath) Extractor for “.zip” files.
deprecated([message, version]) Decorate a class, a method or a function to be deprecated.
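StatisticsCollector computes E[X] and Var[X] online, i.e. without storing all observations. One standard way to do this is Welford's algorithm, sketched below for scalar values. The class `RunningStats` and its method names are hypothetical, it does not model the `shape` argument or any weighting, and TFSnippet may use a different update rule internally.

```python
import math
from typing import Iterable


class RunningStats:
    """Online mean and (population) variance via Welford's algorithm."""

    def __init__(self) -> None:
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # sum of squared deviations from the current mean

    def collect(self, values: Iterable[float]) -> None:
        for x in values:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self._m2 += delta * (x - self.mean)  # uses the updated mean

    @property
    def var(self) -> float:
        return self._m2 / self.count if self.count else 0.0

    @property
    def stddev(self) -> float:
        return math.sqrt(self.var)
```

Welford's update avoids the catastrophic cancellation that computing E[X²] - E[X]² directly can suffer when the mean is large relative to the spread. Collecting [2, 4, 4] and then [4, 5, 5, 7, 9] yields mean 5 and variance 4, exactly as if all eight values had been seen at once.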
Class Inheritance Diagram
Inheritance diagram of tfsnippet.utils.concepts.AutoInitAndCloseable, tfsnippet.utils.registry.BaseRegistry, tfsnippet.utils.config_utils.BoolConfigValidator, tfsnippet.utils.caching.CacheDir, tfsnippet.utils.registry.ClassRegistry, tfsnippet.utils.config_utils.Config, tfsnippet.utils.config_utils.ConfigField, tfsnippet.utils.config_utils.ConfigValidator, tfsnippet.utils.console_table.ConsoleTable, tfsnippet.utils.misc.ContextStack, tfsnippet.utils.concepts.Disposable, tfsnippet.utils.concepts.DisposableContext, tfsnippet.utils.misc.ETA, tfsnippet.utils.events.EventSource, tfsnippet.utils.archive_file.Extractor, tfsnippet.utils.config_utils.FloatConfigValidator, tfsnippet.utils.graph_keys.GraphKeys, tfsnippet.utils.tensor_spec.InputSpec, tfsnippet.utils.config_utils.IntConfigValidator, tfsnippet.utils.invertible_matrix.InvertibleMatrix, tfsnippet.utils.concepts.NoReentrantContext, tfsnippet.utils.tensor_spec.ParamSpec, tfsnippet.utils.invertible_matrix.PermutationMatrix, tfsnippet.utils.archive_file.RarExtractor, tfsnippet.utils.statistics.StatisticsCollector, tfsnippet.utils.config_utils.StrConfigValidator, tfsnippet.utils.summary_collector.SummaryCollector, tfsnippet.utils.settings_.TFSnippetConfig, tfsnippet.utils.archive_file.TarExtractor, tfsnippet.utils.type_utils.TensorArgValidator, tfsnippet.utils.tensor_spec.TensorSpec, tfsnippet.utils.tensor_wrapper.TensorWrapper, tfsnippet.utils.reuse.VarScopeObject, tfsnippet.utils.random.VarScopeRandomState, tfsnippet.utils.archive_file.ZipExtractor, tfsnippet.utils.deprecation.deprecated
