import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import (
    broadcast_all,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)

__all__ = ["Binomial"]


def _clamp_by_zero(x):
    # works like clamp(x, min=0) but the gradient at 0 is 0.5
    return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
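
# Illustrative sketch, not part of the original module: the helper's gradient
# convention can be checked directly with autograd. The function name below is
# hypothetical and for demonstration only; per the comment above, the
# subgradient at exactly 0 comes out as 0.5 (the average of the one-sided
# slopes 0 and 1), whichever boundary convention plain clamp uses.
def _demo_clamp_by_zero_grad():
    x = torch.tensor(0.0, requires_grad=True)
    _clamp_by_zero(x).backward()
    return x.grad  # tensor(0.5000)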
class Binomial(Distribution):
    r"""
    Creates a Binomial distribution parameterized by :attr:`total_count` and
    either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count`
    must be broadcastable with :attr:`probs`/:attr:`logits`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Binomial(100, torch.tensor([0., 0.2, 0.8, 1.]))
        >>> x = m.sample()
        tensor([  0.,  22.,  71., 100.])

        >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
        >>> x = m.sample()
        tensor([[ 4.,  5.],
                [ 7.,  6.]])

    Args:
        total_count (int or Tensor): number of Bernoulli trials
        probs (Tensor): Event probabilities
        logits (Tensor): Event log-odds
    """

    arg_constraints = {
        "total_count": constraints.nonnegative_integer,
        "probs": constraints.unit_interval,
        "logits": constraints.real,
    }
    has_enumerate_support = True

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError(
                "Either `probs` or `logits` must be specified, but not both."
            )
        if probs is not None:
            self.total_count, self.probs = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
        else:
            self.total_count, self.logits = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)

        self._param = self.probs if probs is not None else self.logits
        batch_shape = self._param.size()
        super().__init__(batch_shape, validate_args=validate_args)
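
    # Illustrative sketch, not in the original source: the two
    # parameterizations are interchangeable, with logits = log(p / (1 - p)).
    # A hypothetical doctest-style check (this relies on the lazy `logits`
    # property the full class defines via lazy_property):
    #
    #     >>> via_probs = Binomial(10, probs=torch.tensor([0.5, 0.5]))
    #     >>> via_logits = Binomial(10, logits=torch.zeros(2))
    #     >>> k = torch.tensor([3.0, 7.0])
    #     >>> bool(torch.allclose(via_probs.log_prob(k), via_logits.log_prob(k)))
    #     True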
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        log_factorial_n = torch.lgamma(self.total_count + 1)
        log_factorial_k = torch.lgamma(value + 1)
        log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #   (case logit < 0)  = k * logit - n * log1p(e^logit)
        #   (case logit > 0)  = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                     = k * logit - n * logit - n * log1p(e^-logit)
        #   (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = (
            self.total_count * _clamp_by_zero(self.logits)
            + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
            - log_factorial_n
        )
        return (
            value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
        )
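
    # Illustrative sketch, not in the original source: for moderate parameters
    # the numerically stable form above agrees with the textbook expression
    # log C(n, k) + k * log(p) + (n - k) * log(1 - p). A hypothetical
    # doctest-style check:
    #
    #     >>> import math
    #     >>> n, p = 12, 0.3
    #     >>> d = Binomial(n, probs=torch.tensor(p))
    #     >>> k = torch.arange(n + 1.0)
    #     >>> naive = (torch.lgamma(torch.tensor(n + 1.0)) - torch.lgamma(k + 1)
    #     ...          - torch.lgamma(n - k + 1) + k * math.log(p)
    #     ...          + (n - k) * math.log1p(-p))
    #     >>> bool(torch.allclose(d.log_prob(k), naive, atol=1e-5))
    #     True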
    def entropy(self):
        total_count = int(self.total_count.max())
        if not self.total_count.min() == total_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `entropy`."
            )

        log_prob = self.log_prob(self.enumerate_support(False))
        return -(torch.exp(log_prob) * log_prob).sum(0)
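
    # Illustrative sketch, not in the original source: with total_count = 1
    # the distribution reduces to a Bernoulli, so `entropy` should match the
    # binary entropy -p * log(p) - (1 - p) * log(1 - p). Hypothetical check:
    #
    #     >>> p = torch.tensor(0.2)
    #     >>> h = Binomial(1, probs=p).entropy()
    #     >>> h_ref = -(p * torch.log(p) + (1 - p) * torch.log1p(-p))
    #     >>> bool(torch.allclose(h, h_ref))
    #     True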
    def enumerate_support(self, expand=True):
        total_count = int(self.total_count.max())
        if not self.total_count.min() == total_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported by `enumerate_support`."
            )
        values = torch.arange(
            1 + total_count, dtype=self._param.dtype, device=self._param.device
        )
        values = values.view((-1,) + (1,) * len(self._batch_shape))
        if expand:
            values = values.expand((-1,) + self._batch_shape)
        return values
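
# Illustrative sketch, not part of the original module: a hypothetical helper
# showing the shape contract of `enumerate_support`. The support
# {0, ..., total_count} lies along a new leading dimension; expand=False keeps
# singleton batch dims so the result still broadcasts against batched values.
def _demo_enumerate_support_shapes():
    d = Binomial(3, probs=torch.full((2, 4), 0.5))
    assert d.enumerate_support().shape == (4, 2, 4)  # expand=True (default)
    assert d.enumerate_support(expand=False).shape == (4, 1, 1)
    return d.enumerate_support(expand=False).flatten()  # tensor([0., 1., 2., 3.])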