How to use TORCH.RANDINT in Python
torch.randint(low=0, high, size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor

Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive). The shape of the tensor is defined by the variable argument size. Note: with the global dtype default (torch.float32), this function returns a tensor with dtype torch.int64.

Parameters:
low (int, optional) – Lowest integer to be drawn from the distribution. Default: 0.
high (int) – One above the highest integer to be drawn from the distribution.
size (tuple) – a tuple defining the shape of the output tensor.
generator (torch.Generator, optional) – a pseudorandom number generator for sampling.
out (Tensor, optional) – the output tensor.
dtype (torch.dtype, optional) – the desired data type of the returned tensor; if None, the result has dtype torch.int64.
layout ( device ( requires_grad (bool, optional) – If autograd should record operations on the returned tensor. Default: Example: The following are 30 code examples of torch.randint(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function .Example #1 def get_batch(source, i, train): if train: i = torch.randint(low=0, high=(len(source) - args.bptt), size=(1,)).long().item() seq_len = args.bptt target = source[i + 1:i + 1 + seq_len].t() else: seq_len = min(args.bptt, len(source) - 1 - i) target = source[i + seq_len, :] data = source[i:i + seq_len].t() data_mask = (data != pad).unsqueeze(-2) target_mask = make_std_mask(data.long()) # reshape target to match what cross_entropy expects target = target.contiguous().view(-1) return data, target, data_mask, target_mask Example #2 def __init__(self, thresh=1e-8, projDim=8192, input_dim=512): super(CBP, self).__init__() self.thresh = thresh self.projDim = projDim self.input_dim = input_dim self.output_dim = projDim torch.manual_seed(1) self.h_ = [ torch.randint(0, self.output_dim, (self.input_dim,),dtype=torch.long), torch.randint(0, self.output_dim, (self.input_dim,),dtype=torch.long) ] self.weights_ = [ (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float(), (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float() ] indices1 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1), self.h_[0].reshape(1, -1)), dim=0) indices2 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1), self.h_[1].reshape(1, -1)), dim=0) self.sparseM = [ torch.sparse.FloatTensor(indices1, self.weights_[0], torch.Size([self.input_dim, self.output_dim])).to_dense(), torch.sparse.FloatTensor(indices2, self.weights_[1], torch.Size([self.input_dim, self.output_dim])).to_dense(), ] Example #3 def 
test_adam_poincare(params): torch.manual_seed(44) manifold = geoopt.PoincareBall() ideal = manifold.random(10, 2) start = manifold.random(10, 2) start = geoopt.ManifoldParameter(start, manifold=manifold) def closure(): idx = torch.randint(10, size=(3,)) start_select = torch.nn.functional.embedding(idx, start, sparse=True) ideal_select = torch.nn.functional.embedding(idx, ideal, sparse=True) optim.zero_grad() loss = manifold.dist2(start_select, ideal_select).sum() loss.backward() assert start.grad.is_sparse return loss.item() optim = geoopt.optim.SparseRiemannianSGD([start], **params) for _ in range(2000): optim.step(closure) np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5) Example #4 def test_adam_poincare(params): torch.manual_seed(44) manifold = geoopt.PoincareBall() ideal = manifold.random(10, 2) start = manifold.random(10, 2) start = geoopt.ManifoldParameter(start, manifold=manifold) def closure(): idx = torch.randint(10, size=(3,)) start_select = torch.nn.functional.embedding(idx, start, sparse=True) ideal_select = torch.nn.functional.embedding(idx, ideal, sparse=True) optim.zero_grad() loss = manifold.dist2(start_select, ideal_select).sum() loss.backward() assert start.grad.is_sparse return loss.item() optim = geoopt.optim.SparseRiemannianAdam([start], **params) for _ in range(2000): optim.step(closure) np.testing.assert_allclose(start.data, ideal, atol=1e-5, rtol=1e-5) Example #5 def neg_sample(self, batch): batch = batch.repeat(self.walks_per_node * self.num_negative_samples) rws = [batch] for i in range(self.walk_length): keys = self.metapath[i % len(self.metapath)] batch = torch.randint(0, self.num_nodes_dict[keys[-1]], (batch.size(0), ), dtype=torch.long) rws.append(batch) rw = torch.stack(rws, dim=-1) rw.add_(self.offset.view(1, -1)) walks = [] num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size for j in range(num_walks_per_rw): walks.append(rw[:, j:j + self.context_size]) return torch.cat(walks, dim=0) Example #6 def 
test_gnn_explainer(): model = Net() explainer = GNNExplainer(model, log=False) assert explainer.__repr__() == 'GNNExplainer()' x = torch.randn(8, 3) edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]]) y = torch.randint(0, 6, (8, ), dtype=torch.long) node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index) assert node_feat_mask.size() == (x.size(1), ) assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1 assert edge_mask.size() == (edge_index.size(1), ) assert edge_mask.min() >= 0 and edge_mask.max() <= 1 explainer.visualize_subgraph(2, edge_index, edge_mask, threshold=None) explainer.visualize_subgraph(2, edge_index, edge_mask, threshold=0.5) explainer.visualize_subgraph(2, edge_index, edge_mask, y=y, threshold=None) explainer.visualize_subgraph(2, edge_index, edge_mask, y=y, threshold=0.5) Example #7 def test_deep_graph_infomax(): def corruption(z): return z + 1 model = DeepGraphInfomax( hidden_channels=16, encoder=lambda x: x, summary=lambda z, *args: z.mean(dim=0), corruption=lambda x: x + 1) assert model.__repr__() == 'DeepGraphInfomax(16)' x = torch.ones(20, 16) pos_z, neg_z, summary = model(x) assert pos_z.size() == (20, 16) and neg_z.size() == (20, 16) assert summary.size() == (16, ) loss = model.loss(pos_z, neg_z, summary) assert 0 <= loss.item() acc = model.test( torch.ones(20, 16), torch.randint(10, (20, )), torch.ones(20, 16), torch.randint(10, (20, ))) assert 0 <= acc and acc <= 1 Example #8 def test_node2vec(): edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) model = Node2Vec(edge_index, embedding_dim=16, walk_length=2, context_size=2) assert model.__repr__() == 'Node2Vec(3, 16)' z = model(torch.arange(3)) assert z.size() == (3, 16) pos_rw, neg_rw = model.sample(torch.arange(3)) loss = model.loss(pos_rw, neg_rw) assert 0 <= loss.item() acc = model.test(torch.ones(20, 16), torch.randint(10, (20, )), torch.ones(20, 16), torch.randint(10, (20, ))) assert 0 <= acc 
and acc <= 1 Example #9 def farthest_point_sample(xyz, npoint): """ Input: xyz: pointcloud data, [B, N, 3] npoint: number of samples Return: centroids: sampled pointcloud index, [B, npoint] """ device = xyz.device B, N, C = xyz.shape centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) distance = torch.ones(B, N).to(device) * 1e10 farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) batch_indices = torch.arange(B, dtype=torch.long).to(device) for i in range(npoint): centroids[:, i] = farthest centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) dist = torch.sum((xyz - centroid) ** 2, -1) mask = dist < distance distance[mask] = dist[mask] farthest = torch.max(distance, -1)[1] return centroids Example #10 def forward_attr(self, e, mode='left'): assert mode == 'left' or mode == 'right' e_emb = self.emb_e(e.view(-1)) # Sample one numerical literal for each entity e_attr = self.numerical_literals[e.view(-1)] m = len(e_attr) idxs = torch.randint(self.n_num_lit, size=(m,)).cuda() attr_emb = self.emb_attr(idxs) inputs = torch.cat([e_emb, attr_emb], dim=1) pred = self.attr_net_left(inputs) if mode == 'left' else self.attr_net_right(inputs) target = e_attr[range(m), idxs] return pred, target Example #11 def forward(self, pos): r"""Memory allocation and sampling Parameters ---------- pos : tensor The positional tensor of shape (B, N, C) Returns ------- tensor of shape (B, self.npoints) The sampled indices in each batch. 
""" device = pos.device B, N, C = pos.shape pos = pos.reshape(-1, C) dist = th.zeros((B * N), dtype=pos.dtype, device=device) start_idx = th.randint(0, N - 1, (B, ), dtype=th.long, device=device) result = th.zeros((self.npoints * B), dtype=th.long, device=device) farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result) return result.reshape(B, self.npoints) Example #12 def add_insertion_noise(self, tokens, p): if p == 0.0: return tokens num_tokens = len(tokens) n = int(math.ceil(num_tokens * p)) noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1 noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool) noise_mask[noise_indices] = 1 result = torch.LongTensor(n + len(tokens)).fill_(-1) num_random = int(math.ceil(n * self.random_ratio)) result[noise_indices[num_random:]] = self.mask_idx result[noise_indices[:num_random]] = torch.randint(low=1, high=len(self.vocab), size=(num_random,)) result[~noise_mask] = tokens assert (result >= 0).all() return result Example #13 def test_cutmix(self): random_image = torch.rand(5, 3, 100, 100) state = {torchbearer.X: random_image, torchbearer.Y_TRUE: torch.randint(10, (5,)).long(), torchbearer.DEVICE: 'cpu'} torch.manual_seed(7) co = CutMix(0.25, classes=10) co.on_sample(state) reg_img = state[torchbearer.X].view(-1) x = [72, 83, 18, 96, 40] y = [8, 17, 62, 30, 66] perm = [0, 4, 3, 2, 1] sz = 3 rnd = random_image.clone().numpy() known_cut = random_image.clone().numpy() known_cut[0, :, y[0]-sz//2:y[0]+sz//2, x[0]-sz//2:x[0]+sz//2] = rnd[perm[0], :, y[0]-sz//2:y[0]+sz//2, x[0]-sz//2:x[0]+sz//2] known_cut[1, :, y[1]-sz//2:y[1]+sz//2, x[1]-sz//2:x[1]+sz//2] = rnd[perm[1], :, y[1]-sz//2:y[1]+sz//2, x[1]-sz//2:x[1]+sz//2] known_cut[2, :, y[2]-sz//2:y[2]+sz//2, x[2]-sz//2:x[2]+sz//2] = rnd[perm[2], :, y[2]-sz//2:y[2]+sz//2, x[2]-sz//2:x[2]+sz//2] known_cut[3, :, y[3]-sz//2:y[3]+sz//2, x[3]-sz//2:x[3]+sz//2] = rnd[perm[3], :, y[3]-sz//2:y[3]+sz//2, x[3]-sz//2:x[3]+sz//2] known_cut[4, :, y[4]-sz//2:y[4]+sz//2, 
x[4]-sz//2:x[4]+sz//2] = rnd[perm[4], :, y[4]-sz//2:y[4]+sz//2, x[4]-sz//2:x[4]+sz//2] known_cut = torch.from_numpy(known_cut) known_cut = known_cut.view(-1) diff = (torch.abs(known_cut-reg_img) > 1e-4).any() self.assertTrue(diff.item() == 0) Example #14 def example_mdpooling(): input = torch.randn(2, 32, 64, 64).cuda() input.requires_grad = True batch_inds = torch.randint(2, (20, 1)).cuda().float() x = torch.randint(256, (20, 1)).cuda().float() y = torch.randint(256, (20, 1)).cuda().float() w = torch.randint(64, (20, 1)).cuda().float() h = torch.randint(64, (20, 1)).cuda().float() rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) # mdformable pooling (V2) dpooling = DCNPooling(spatial_scale=1.0 / 4, pooled_size=7, output_dim=32, no_trans=False, group_size=1, trans_std=0.1, deform_fc_dim=1024).cuda() dout = dpooling(input, rois) target = dout.new(*dout.size()) target.data.uniform_(-0.1, 0.1) error = (target - dout).mean() error.backward() print(dout.shape) Example #15 def build_fss_keys(self, type_op): """ The builder to generate functional keys for Function Secret Sharing (FSS) """ if type_op == "eq": fss_class = sy.frameworks.torch.mpc.fss.DPF elif type_op == "comp": fss_class = sy.frameworks.torch.mpc.fss.DIF else: raise ValueError(f"type_op {type_op} not valid") n = sy.frameworks.torch.mpc.fss.n def build_separate_fss_keys(n_party, n_instances=100): assert ( n_party == 2 ), f"The FSS protocol only works for 2 workers, {n_party} were provided." 
alpha, s_00, s_01, *CW = fss_class.keygen(n_values=n_instances) # simulate sharing TODO clean this mask = th.randint(0, 2 ** n, alpha.shape) return [((alpha - mask) % 2 ** n, s_00, *CW), (mask, s_01, *CW)] return build_separate_fss_keys Example #16 def test_encrypt_decrypt(workers): bob, alice, james = (workers["bob"], workers["alice"], workers["james"]) x = torch.randint(10, (1, 5), dtype=torch.float32) x_encrypted = x.encrypt(workers=[bob, alice], crypto_provider=james, base=10) x_decrypted = x_encrypted.decrypt() assert torch.all(torch.eq(x_decrypted, x)) x = torch.randint(10, (1, 5), dtype=torch.float32) x_encrypted = x.encrypt(workers=[bob, alice], crypto_provider=james) x_decrypted = x_encrypted.decrypt() assert torch.all(torch.eq(x_decrypted, x)) x = torch.randint(10, (1, 5), dtype=torch.float32) public, private = syft.frameworks.torch.he.paillier.keygen() x_encrypted = x.encrypt(protocol="paillier", public_key=public) x_decrypted = x_encrypted.decrypt(protocol="paillier", private_key=private) assert torch.all(torch.eq(x_decrypted, x)) Example #17 def test_save_load(self): bert_save_test = 'roberta_save_test' try: os.makedirs(bert_save_test, exist_ok=True) vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) embed = RobertaEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_roberta', word_dropout=0.1, auto_truncate=True) embed.save(bert_save_test) load_embed = RobertaEmbedding.load(bert_save_test) words = torch.randint(len(vocab), size=(2, 20)) embed.eval(), load_embed.eval() self.assertEqual((embed(words) - load_embed(words)).sum(), 0) finally: import shutil shutil.rmtree(bert_save_test) Example #18 def test_save_load(self): bert_save_test = 'bert_save_test' try: os.makedirs(bert_save_test, exist_ok=True) vocab = Vocabulary().add_word_lst("this is a test . 
[SEP] NotInBERT".split()) embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, auto_truncate=True) embed.save(bert_save_test) load_embed = BertEmbedding.load(bert_save_test) words = torch.randint(len(vocab), size=(2, 20)) embed.eval(), load_embed.eval() self.assertEqual((embed(words) - load_embed(words)).sum(), 0) finally: import shutil shutil.rmtree(bert_save_test) Example #19 def test_lsloss(): pred = torch.rand(3, 10) label = torch.randint(0, 10, size=(3,)) Loss = LabelSmoothingLoss(10, 0.1) Loss1 = nn.CrossEntropyLoss() cost = Loss(pred, label) cost1 = Loss1(pred, label) assert cost.shape == cost1.shape Example #20 def test_logits_loss(): pred = torch.rand(3, 10) label = torch.randint(0, 10, size=(3,)) weight = class_balanced_weight(0.9999, np.random.randint(0, 100, size=(10,)).tolist()) Loss = SigmoidCrossEntropy(classes=10, weight=weight) Loss1 = FocalLoss(classes=10, weight=weight, gamma=0.5) Loss2 = ArcLoss(classes=10, weight=weight) cost = Loss(pred, label) cost1 = Loss1(pred, label) cost2 = Loss2(pred, label) print(cost, cost1, cost2) Example #21 def sample_p_0(device, replay_buffer, bs, y=None): if len(replay_buffer) == 0: return init_random(bs), [] buffer_size = len(replay_buffer) if y is None else len(replay_buffer) // n_classes inds = t.randint(0, buffer_size, (bs,)) # if cond, convert inds to class conditional inds if y is not None: inds = y.cpu() * buffer_size + inds assert not args.uncond, "Can't drawn conditional samples without giving me y" buffer_samples = replay_buffer[inds] random_samples = init_random(bs) choose_random = (t.rand(bs) < args.reinit_freq).float()[:, None, None, None] samples = choose_random * random_samples + (1 - choose_random) * buffer_samples return samples.to(device), inds Example #22 def get_sample_q(args, device): def sample_p_0(replay_buffer, bs, y=None): if len(replay_buffer) == 0: return init_random(args, bs), [] buffer_size = len(replay_buffer) if y is None else 
len(replay_buffer) // args.n_classes inds = t.randint(0, buffer_size, (bs,)) # if cond, convert inds to class conditional inds if y is not None: inds = y.cpu() * buffer_size + inds assert not args.uncond, "Can't drawn conditional samples without giving me y" buffer_samples = replay_buffer[inds] random_samples = init_random(args, bs) choose_random = (t.rand(bs) < args.reinit_freq).float()[:, None, None, None] samples = choose_random * random_samples + (1 - choose_random) * buffer_samples return samples.to(device), inds def sample_q(f, replay_buffer, y=None, n_steps=args.n_steps): """this func takes in replay_buffer now so we have the option to sample from scratch (i.e. replay_buffer==[]). See test_wrn_ebm.py for example. """ f.eval() # get batch size bs = args.batch_size if y is None else y.size(0) # generate initial samples and buffer inds of those samples (if buffer is used) init_sample, buffer_inds = sample_p_0(replay_buffer, bs=bs, y=y) x_k = t.autograd.Variable(init_sample, requires_grad=True) # sgld for k in range(n_steps): f_prime = t.autograd.grad(f(x_k, y=y).sum(), [x_k], retain_graph=True)[0] x_k.data += args.sgld_lr * f_prime + args.sgld_std * t.randn_like(x_k) f.train() final_samples = x_k.detach() # update replay buffer if len(replay_buffer) > 0: replay_buffer[buffer_inds] = final_samples.cpu() return final_samples return sample_q Example #23 def forward(self, xyz, features, end_points, mode=''): """ Args: xyz: (B,K,3) features: (B,C,K) Returns: scores: (B,num_proposal,2+3+NH*2+NS*4) """ if self.sampling == 'vote_fps': # Farthest point sampling (FPS) on votes original_feature = features xyz, features, fps_inds = self.vote_aggregation(xyz, features) #original_feature = torch.gather(original_features, 2, fps_inds.unsqueeze(1).repeat(1,256,1).detach().long()).contiguous() sample_inds = fps_inds elif self.sampling == 'seed_fps': # FPS on seed and choose the votes corresponding to the seeds # This gets us a slightly better coverage of *object* votes than 
vote_fps (which tends to get more cluster votes) sample_inds = pointnet2_utils.furthest_point_sample(end_points['seed_xyz'], self.num_proposal) xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds) elif self.sampling == 'random': # Random sampling from the votes num_seed = end_points['seed_xyz'].shape[1] sample_inds = torch.randint(0, num_seed, (batch_size, self.num_proposal), dtype=torch.int).cuda() xyz, features, _ = self.vote_aggregation(xyz, features, sample_inds) else: log_string('Unknown sampling strategy: %s. Exiting!'%(self.sampling)) exit() end_points['aggregated_vote_xyz'+mode] = xyz # (batch_size, num_proposal, 3) end_points['aggregated_vote_inds'+mode] = sample_inds # (batch_size, num_proposal,) # should be 0,1,2,...,num_proposal end_points['aggregated_feature'+mode] = features # --------- PROPOSAL GENERATION --------- net = F.relu(self.bn1(self.conv1(features))) last_net = F.relu(self.bn2(self.conv2(net))) net = self.conv3(last_net) # (batch_size, 2+3+num_heading_bin*2+num_size_cluster*4, num_proposal) newcenter, end_points = decode_scores(net, end_points, self.num_class, mode=mode) return newcenter.contiguous(), features.contiguous(), end_points Example #24 def getitem(self, index): if self.split=='train': index = int(torch.randint(self.nVideos,())) img = self.LoadImage(index) pts, c, s = self.GetPartInfo(index) r = 0 if self.split == 'train': s = s * (2 ** Rnd(ref.scale)) r = 0 if np.random.random() < 0.6 else Rnd(ref.rotate) inp = Crop(img, c, s, r, ref.inputRes) / 256. 
out = np.zeros((ref.nJoints, ref.outputRes, ref.outputRes)) Reg = np.zeros((ref.nJoints, 3)) for i in range(ref.nJoints): if pts[i][0] > 1: pt = Transform(pts[i], c, s, r, ref.outputRes) out[i] = DrawGaussian(out[i], pt, ref.hmGauss) Reg[i, :2] = pt Reg[i, 2] = 1 if self.split == 'train': if np.random.random() < 0.5: inp = Flip(inp) out = ShuffleLR(Flip(out)) Reg[:, 1] = Reg[:, 1] * -1 Reg = ShuffleLR(Reg) #print 'before', inp[0].max(), inp[0].mean() inp[0] = np.clip(inp[0] * (np.random.random() * (0.4) + 0.6), 0, 1) inp[1] = np.clip(inp[1] * (np.random.random() * (0.4) + 0.6), 0, 1) inp[2] = np.clip(inp[2] * (np.random.random() * (0.4) + 0.6), 0, 1) #print 'after', inp[0].max(), inp[0].mean() meta = (np.zeros((ref.nJoints, 3))) if self.returnMeta: return inp, out, Reg, meta else: return inp, out Example #25 def _create_data(self): x1 = torch.rand(self.num_points) * 4 - 2 x2_ = torch.rand(self.num_points) - torch.randint(0, 2, [self.num_points]).float() * 2 x2 = x2_ + torch.floor(x1) % 2 self.data = torch.stack([x1, x2]).t() * 2 Example #26 def _get_mask_and_degrees(cls, in_degrees, out_features, autoregressive_features, random_mask, is_output): if is_output: out_degrees = utils.tile( _get_input_degrees(autoregressive_features), out_features // autoregressive_features ) mask = (out_degrees[..., None] > in_degrees).float() else: if random_mask: min_in_degree = torch.min(in_degrees).item() min_in_degree = min(min_in_degree, autoregressive_features - 1) out_degrees = torch.randint( low=min_in_degree, high=autoregressive_features, size=[out_features], dtype=torch.long) else: max_ = max(1, autoregressive_features - 1) min_ = min(1, autoregressive_features - 1) out_degrees = torch.arange(out_features) % max_ + min_ mask = (out_degrees[..., None] >= in_degrees).float() return mask, out_degrees Example #27 def sample_rademacher_like(y): return torch.randint(low=0, high=2, size=y.shape).to(y) * 2 - 1 # -------------- Helper functions -------------- Example #28 def 
__sample_nodes__(self, batch_size): edge_sample = torch.randint(0, self.E, (batch_size, self.batch_size), dtype=torch.long) return self.adj_t.storage.row()[edge_sample] Example #29 def __sample_nodes__(self, batch_size): start = torch.randint(0, self.N, (batch_size, ), dtype=torch.long) node_idx = self.adj_t.random_walk(start.flatten(), self.walk_length) return node_idx.view(-1) Example #30 def structured_negative_sampling(edge_index, num_nodes=None): r"""Samples a negative edge :obj:`(i,k)` for every positive edge :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a tuple of the form :obj:`(i,j,k)`. Args: edge_index (LongTensor): The edge indices. num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) :rtype: (LongTensor, LongTensor, LongTensor) """ num_nodes = maybe_num_nodes(edge_index, num_nodes) i, j = edge_index.to('cpu') idx_1 = i * num_nodes + j k = torch.randint(num_nodes, (i.size(0), ), dtype=torch.long) idx_2 = i * num_nodes + k mask = torch.from_numpy(np.isin(idx_2, idx_1)).to(torch.bool) rest = mask.nonzero().view(-1) while rest.numel() > 0: # pragma: no cover tmp = torch.randint(num_nodes, (rest.numel(), ), dtype=torch.long) idx_2 = i[rest] * num_nodes + tmp mask = torch.from_numpy(np.isin(idx_2, idx_1)).to(torch.bool) k[rest] = tmp rest = rest[mask.nonzero().view(-1)] return edge_index[0], edge_index[1], k.to(edge_index.device) |