import torch


class RBM:
    """Bernoulli-Bernoulli Restricted Boltzmann Machine trained with CD-k."""

    def __init__(self, num_visible, num_hidden, k, learning_rate=1e-3,
                 momentum_coefficient=0.5, weight_decay=1e-4, use_cuda=True):
        self.num_visible = num_visible
        self.num_hidden = num_hidden
        self.k = k  # number of Gibbs sampling steps in the negative phase
        self.learning_rate = learning_rate
        self.momentum_coefficient = momentum_coefficient
        self.weight_decay = weight_decay
        self.use_cuda = use_cuda

        # Small random weights; visible biases start slightly positive, hidden biases at zero.
        self.weights = torch.randn(num_visible, num_hidden) * 0.1
        self.visible_bias = torch.ones(num_visible) * 0.5
        self.hidden_bias = torch.zeros(num_hidden)

        # Momentum buffers, one per parameter group.
        self.weights_momentum = torch.zeros(num_visible, num_hidden)
        self.visible_bias_momentum = torch.zeros(num_visible)
        self.hidden_bias_momentum = torch.zeros(num_hidden)

        if self.use_cuda:
            self.weights = self.weights.cuda()
            self.visible_bias = self.visible_bias.cuda()
            self.hidden_bias = self.hidden_bias.cuda()

            self.weights_momentum = self.weights_momentum.cuda()
            self.visible_bias_momentum = self.visible_bias_momentum.cuda()
            self.hidden_bias_momentum = self.hidden_bias_momentum.cuda()

    def sample_hidden(self, visible_probabilities):
        # p(h = 1 | v) = sigmoid(v W + c)
        hidden_activations = torch.matmul(visible_probabilities, self.weights) + self.hidden_bias
        hidden_probabilities = self._sigmoid(hidden_activations)
        return hidden_probabilities

    def sample_visible(self, hidden_probabilities):
        # p(v = 1 | h) = sigmoid(h W^T + b)
        visible_activations = torch.matmul(hidden_probabilities, self.weights.t()) + self.visible_bias
        visible_probabilities = self._sigmoid(visible_activations)
        return visible_probabilities

    def contrastive_divergence(self, input_data):
        # Positive phase: hidden probabilities driven by the data, binarized by
        # comparison against uniform random thresholds.
        positive_hidden_probabilities = self.sample_hidden(input_data)
        positive_hidden_activations = (positive_hidden_probabilities >= self._random_probabilities(self.num_hidden)).float()
        positive_associations = torch.matmul(input_data.t(), positive_hidden_activations)

        # Negative phase: k steps of alternating Gibbs sampling, starting from
        # the data-driven hidden sample.
        hidden_activations = positive_hidden_activations
        for step in range(self.k):
            visible_probabilities = self.sample_visible(hidden_activations)
            hidden_probabilities = self.sample_hidden(visible_probabilities)
            hidden_activations = (hidden_probabilities >= self._random_probabilities(self.num_hidden)).float()

        negative_visible_probabilities = visible_probabilities
        negative_hidden_probabilities = hidden_probabilities
        negative_associations = torch.matmul(negative_visible_probabilities.t(), negative_hidden_probabilities)

        # Parameter updates with momentum: the CD gradient estimate is the
        # difference between the positive- and negative-phase statistics.
        self.weights_momentum *= self.momentum_coefficient
        self.weights_momentum += (positive_associations - negative_associations)

        self.visible_bias_momentum *= self.momentum_coefficient
        self.visible_bias_momentum += torch.sum(input_data - negative_visible_probabilities, dim=0)

        self.hidden_bias_momentum *= self.momentum_coefficient
        self.hidden_bias_momentum += torch.sum(positive_hidden_probabilities - negative_hidden_probabilities, dim=0)

        batch_size = input_data.size(0)

        self.weights += self.weights_momentum * self.learning_rate / batch_size
        self.visible_bias += self.visible_bias_momentum * self.learning_rate / batch_size
        self.hidden_bias += self.hidden_bias_momentum * self.learning_rate / batch_size

        self.weights -= self.weights * self.weight_decay  # L2 weight decay

        # Reconstruction error: squared distance between the input and its reconstruction.
        error = torch.sum((input_data - negative_visible_probabilities)**2)

        return error

    def _sigmoid(self, x):
        return 1 / (1 + torch.exp(-x))

    def _random_probabilities(self, num):
        # Uniform random thresholds used to binarize unit probabilities into samples.
        random_probabilities = torch.rand(num)

        if self.use_cuda:
            random_probabilities = random_probabilities.cuda()

        return random_probabilities
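
# Minimal usage sketch (not part of the original class): trains the RBM with
# CD-k on one synthetic batch of binary vectors. The dimensions, batch size,
# epoch count, and the random `data` tensor below are illustrative assumptions,
# chosen only to show how the class is driven.
if __name__ == '__main__':
    rbm = RBM(num_visible=784, num_hidden=128, k=1, use_cuda=False)

    # Synthetic binary data standing in for, e.g., binarized MNIST digits.
    data = (torch.rand(64, 784) > 0.5).float()

    for epoch in range(10):
        error = rbm.contrastive_divergence(data)
        print('Epoch %d, reconstruction error: %.4f' % (epoch, error.item()))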